mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] ABQuant New Preshuffle (#3638)
* Refactor * Gemm quant improvement * Change preshuffle * Fix * Fix grouped gemm ut * Fix --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -69,7 +69,7 @@ auto shuffle_bq(const ck_tile::HostTensor<T>* t, int block_bq_k)
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t, GemmConfig)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
@@ -79,36 +79,40 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
else if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
k_ / gemmConfig.K_Warp_Tile,
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
gemmConfig.K_Warp_Tile / divisor});
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
|
||||
constexpr int ItemsPerAccess =
|
||||
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / ItemsPerAccess,
|
||||
ItemsPerAccess});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
|
||||
Reference in New Issue
Block a user