Merge some updates for ck_tile headers (#3342)

* fix some issues from internal branch

* update cshuffle_epilogue

* update cshuffle_epilogue

* update cshuffle

* update warp_gemm
This commit is contained in:
joyeamd
2026-01-06 15:39:00 +08:00
committed by GitHub
parent 2b563ad048
commit b78563b3d3
14 changed files with 205 additions and 119 deletions

View File

@@ -1124,8 +1124,14 @@ CK_TILE_DEVICE static constexpr auto get_device_arch()
{
// FIXME(0): on all devices except gfx11 it returns gfx12_t
// FIXME(1): during the host compilation pass it returns gfx12_t
#if defined(__gfx11__)
#if defined(__gfx103__)
return gfx103_t{};
#elif defined(__gfx11__)
return gfx11_t{};
#elif defined(__gfx950__)
return gfx950_t{};
#elif defined(__gfx9__)
return gfx9_t{};
#else
return gfx12_t{};
#endif
@@ -1146,26 +1152,10 @@ CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }
CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
{
#if defined(__gfx103__)
return gfx103_t{};
#elif defined(__gfx11__)
return gfx11_t{};
#elif defined(__gfx12__)
return gfx12_t{};
#elif defined(__gfx950__)
return gfx950_t{};
#elif defined(__gfx9__)
return gfx9_t{};
#else
return gfx_invalid_t{};
#endif
}
} // namespace detail
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
{
return detail::get_n_lds_banks(detail::arch_tag_dispatch());
return detail::get_n_lds_banks(get_device_arch());
}
enum LLVMSchedGroupMask : int32_t

View File

@@ -34,46 +34,23 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();
// y_dim_out_to_in
// For swapped Hs tile case I need only get_rh_minor_to_y
// since rh_major are already swapped due to swapped Hs.
constexpr auto get_rh_minor_to_y = [](auto dstr_tensor) {
using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;
map<index_t, index_t> rh_minor_to_y_;
static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];
rh_minor_to_y_(rh_minor) = i;
});
return rh_minor_to_y_;
};
// In swapped Hs case <Y,X> -> <X,Y> tile
// we have same rh_major, but reversed rh_minor!
constexpr auto rh_minor_to_y_in = get_rh_minor_to_y(InTensor{});
constexpr auto rh_minor_to_y_out = get_rh_minor_to_y(OutTensor{});
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
// Is this really needed?? Should we have simple reverse here??
constexpr auto y_dim_out_to_in = [&] {
map<index_t, index_t> y_dim_out_to_in_;
for(const auto& [rh_minor, y_out] : rh_minor_to_y_out)
{
y_dim_out_to_in_(y_out) = rh_minor_to_y_in[rh_minor];
}
static_for<0, NDimY, 1>{}([&](auto i) { y_dim_out_to_in_(i) = NDimY - 1 - i; });
return y_dim_out_to_in_;
}();
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
constexpr index_t y_dim_vec_in = NDimY - 1;
constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];
constexpr index_t y_dim_vec_out = 0;
// vector lengths
constexpr index_t vec_length_in = y_lengths[y_dim_vec_in];