mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
fix build error (#3195)
Co-authored-by: root <root@hjbog-srdc-39.amd.com>
[ROCm/composable_kernel commit: 4d629cd2b0]
This commit is contained in:
@@ -47,7 +47,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto shuffle_b_v0(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
|
||||
@@ -103,7 +103,7 @@ int run_flatmm_example_with_layouts(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b<FlatmmConfig>(b_origin_host);
|
||||
return shuffle_b_v0<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
}();
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -662,17 +662,21 @@ struct FlatmmKernel
|
||||
|
||||
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_m_ptr.ptr,
|
||||
make_tuple(
|
||||
kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
|
||||
make_tuple(kargs.M / ScaleGranularityM,
|
||||
ScaleGranularityKA == 0
|
||||
? 1
|
||||
: splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
|
||||
make_tuple(scale_stride_m, 0),
|
||||
number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
|
||||
number<1>{});
|
||||
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
kargs.scale_n_ptr.ptr,
|
||||
make_tuple(
|
||||
ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(ScaleGranularityKB == 0
|
||||
? 1
|
||||
: (splitk_batch_offset.splitted_k /
|
||||
(ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
|
||||
kargs.N / ScaleGranularityN),
|
||||
make_tuple(0, scale_stride_n),
|
||||
number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
|
||||
number<1>{});
|
||||
|
||||
Reference in New Issue
Block a user