mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
fix example
This commit is contained in:
@@ -17,21 +17,17 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if constexpr(GemmConfig<T>::N_Warp_Tile == 32)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / 32, 32, k_ / GemmConfig<T>::K_Warp_Tile, 2, GemmConfig<T>::K_Warp_Tile / 2});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(GemmConfig<T>::N_Warp_Tile == 16);
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / 16, 16, k_ / GemmConfig<T>::K_Warp_Tile, 4, GemmConfig<T>::K_Warp_Tile / 4});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
constexpr int N_Warp_Tile = GemmConfig<T>::N_Warp_Tile;
|
||||
constexpr int N_Warp = GemmConfig<T>::N_Warp;
|
||||
constexpr int KPerLane = GemmConfig<T>::K_Warp_Tile / (64 / N_Warp_Tile);
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
k_ / (64 * KPerLane / N_Warp_Tile),
|
||||
64 / N_Warp_Tile,
|
||||
KPerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
@@ -414,7 +410,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(std::size_t(M), -1);
|
||||
std::vector<ck_tile::index_t> m_indices(M);
|
||||
int indices_fill_start = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
@@ -428,9 +424,11 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensor);
|
||||
c_m_n_tensor.SetZero();
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
|
||||
constexpr int N_Warp_Tile = GemmConfig<BDataType>::N_Warp_Tile;
|
||||
assert(N % N_Warp_Tile == 0 &&
|
||||
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
@@ -444,6 +442,9 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
|
||||
|
||||
ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{
|
||||
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
|
||||
Reference in New Issue
Block a user