fix example

This commit is contained in:
Feng Shijie
2025-07-10 07:01:39 +00:00
parent 4d80c56d07
commit 39bcd46599

View File

@@ -17,21 +17,17 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t)
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if constexpr(GemmConfig<T>::N_Warp_Tile == 32)
{
ck_tile::HostTensor<T> t_view(
{n_ / 32, 32, k_ / GemmConfig<T>::K_Warp_Tile, 2, GemmConfig<T>::K_Warp_Tile / 2});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
else
{
static_assert(GemmConfig<T>::N_Warp_Tile == 16);
ck_tile::HostTensor<T> t_view(
{n_ / 16, 16, k_ / GemmConfig<T>::K_Warp_Tile, 4, GemmConfig<T>::K_Warp_Tile / 4});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
constexpr int N_Warp_Tile = GemmConfig<T>::N_Warp_Tile;
constexpr int N_Warp = GemmConfig<T>::N_Warp;
constexpr int KPerLane = GemmConfig<T>::K_Warp_Tile / (64 / N_Warp_Tile);
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
N_Warp_Tile,
k_ / (64 * KPerLane / N_Warp_Tile),
64 / N_Warp_Tile,
KPerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
@@ -414,7 +410,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
std::vector<ck_tile::index_t> m_indices(std::size_t(M), -1);
std::vector<ck_tile::index_t> m_indices(M);
int indices_fill_start = 0;
for(int i = 0; i < group_count; ++i)
{
@@ -428,9 +424,11 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
}
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-4.f, 4.f}(b_k_n_tensor);
c_m_n_tensor.SetZero();
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
constexpr int N_Warp_Tile = GemmConfig<BDataType>::N_Warp_Tile;
assert(N % N_Warp_Tile == 0 &&
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<BDataType>(b_k_n_tensor);
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
@@ -444,6 +442,9 @@ int run_contiguous_grouped_flatmm_example_with_layouts(
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
m_indices_dev_buf.ToDevice(m_indices.data());
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{
static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
M,