Update include path to break the remod's cyclic dep issue (#2978)

* Update include path to break the cyclic dep issue

* Use ck_tile::permute_vectors_i4x4_b in tile engine

---------

Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
damien-lejeune
2025-10-13 13:24:47 +02:00
committed by GitHub
parent e9f0cc83a8
commit 46c10c316d
25 changed files with 51 additions and 61 deletions

View File

@@ -74,58 +74,6 @@ constexpr auto is_row_major(Layout)
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Permutation function for pk_int4_t
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int8_t input[8];
for(int k = 0; k < 4; k++)
{
int8_t i4x2 = tensor(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int8_t hi = input[2];
int8_t lo = input[0];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 0, i) = i4x2;
}
{
int8_t hi = input[6];
int8_t lo = input[4];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 2, i) = i4x2;
}
{
int8_t hi = input[3];
int8_t lo = input[1];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 4, i) = i4x2;
}
{
int8_t hi = input[7];
int8_t lo = input[5];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 6, i) = i4x2;
}
}
}
}
// Structure to hold kernel traits for dispatcher
struct KernelTraits
{

View File

@@ -96,7 +96,7 @@ class GemmProfiler
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
// permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
permute_vectors_i4x4_b(b_k_n_dev);
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else