[CK][CK Tile] Improve access for merged groups and remove modulo from xor (#5334)

## Motivation

[CK][CK Tile] Improve access for merged groups and remove modulo from
xor

## Technical Details

- add template parameter to xor if modulo is needed. We don't need
modulo for merged groups
- use access by m for merged groups for a tensor
- 
## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

passed locally

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Bartłomiej Kocot
2026-03-20 16:45:45 +01:00
committed by GitHub
parent e785241250
commit 2e22c67ce6
5 changed files with 96 additions and 42 deletions

View File

@@ -1298,7 +1298,7 @@ CK_TILE_HOST_DEVICE static void print(const modulo<Modulus, UpLength>& m)
}
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths>
template <typename LowLengths, bool ApplyModulo = true>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
@@ -1330,8 +1330,15 @@ struct xor_t : public base_transform<2, 2>
idx_low(number<0>{}) = idx_up[number<0>{}];
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
if constexpr(ApplyModulo)
{
idx_low(number<1>{}) =
idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
}
else
{
idx_low(number<1>{}) = idx_up[number<1>{}] ^ (idx_up[number<0>{}]);
}
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
@@ -1382,8 +1389,8 @@ struct xor_t : public base_transform<2, 2>
}
};
template <typename LowLengths>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths>& x)
template <typename LowLengths, bool ApplyModulo = true>
CK_TILE_HOST_DEVICE static void print(const xor_t<LowLengths, ApplyModulo>& x)
{
printf("xor_t{");
printf("up_lengths_: ");
@@ -1737,10 +1744,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths>
template <typename LowLengths, bool ApplyModulo = true>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
return xor_t<LowLengths>{low_lengths};
return xor_t<LowLengths, ApplyModulo>{low_lengths};
}
template <typename LowLength, typename OffsetLength>