fix clang format

This commit is contained in:
illsilin
2025-03-28 11:52:40 -07:00
parent 656a3657cb
commit d3fb5a9b8d
8 changed files with 191 additions and 182 deletions

View File

@@ -255,7 +255,7 @@ int main(int argc, char* argv[])
// max_token_id.mData[0] = valid_size;
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 3, 3, 3};
//max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
// max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8};
// int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
@@ -419,20 +419,19 @@ int main(int argc, char* argv[])
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
B0DataType,
D0DataType,
D1DataType,
D2DataType,
float,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm2<A0DataType,
B0DataType,
D0DataType,
D1DataType,
D2DataType,
float,
AccDataType,
PassThrough,
PassThrough,
CDEElementOp>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,

View File

@@ -58,45 +58,45 @@ template <index_t BlockSize,
index_t KPack
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack> : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -241,7 +241,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
CThreadBuffer& c_thread_buf,
CThreadBuffer& c_thread_buf_up,
index_t num_loop) const
{
ignore = b_block_buf;
__builtin_amdgcn_sched_barrier(0);
@@ -258,7 +258,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs_up;
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}>
b_thread_dequant_bufs_up;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
@@ -268,10 +269,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -299,11 +300,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0),
@@ -330,10 +331,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
@@ -358,9 +359,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_dequant_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_dequant_bufs_up
[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
@@ -428,10 +429,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<BlockGemmPi
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

View File

@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now...
>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -141,8 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::MWaves;
using Base::c_thread_desc_;
using Base::MWaves;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
@@ -186,8 +186,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
{
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2;
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
constexpr auto num_buffer_load_inst_b =
HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2;
constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2;
// B global
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
@@ -276,10 +277,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
@@ -327,10 +328,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
@@ -354,8 +355,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
@@ -368,7 +369,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
@@ -410,10 +411,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
@@ -445,7 +446,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec_up.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
@@ -537,7 +538,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
@@ -567,7 +568,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlockGemmPipelineSch
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -43,53 +43,58 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
{
if constexpr(std::is_same<ADataType, BDataType>::value)
{
if constexpr(GUFusion) {
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
} else {
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else
{
if constexpr(GUFusion) {
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1<
BlkGemmPipeSche,
BlockSize,
@@ -112,7 +117,8 @@ constexpr auto BlockGemmBPreshufflePipeline_Selector()
NRepeat,
KPack>{};
}
else {
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1<
BlkGemmPipeSche,
BlockSize,

View File

@@ -335,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_base
return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCThreadDesc() {
return c_thread_desc_;
}
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;

View File

@@ -397,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
IndexType scatter_offset = 0;
if constexpr(OutputScatter)
{
@@ -408,7 +408,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);

View File

@@ -25,7 +25,7 @@ template <AddressSpaceEnum BufferAddressSpace,
typename ElementSpaceSize,
bool InvalidElementUseNumericalZeroValue,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
typename IndexType = index_t>
typename IndexType = index_t>
struct DynamicBuffer
{
using type = T;
@@ -380,13 +380,14 @@ struct DynamicBuffer
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
bool constexpr use_amd_buffer_addressing =
sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
sizeof(IndexType) <= sizeof(int32_t) && (
is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
sizeof(IndexType) <= sizeof(int32_t) &&
(is_same_v<remove_cvref_t<scalar_t>, float> ||
(is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -424,8 +425,9 @@ struct DynamicBuffer
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
bool constexpr use_amd_buffer_addressing =
sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -462,7 +464,8 @@ template <AddressSpaceEnum BufferAddressSpace,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
typename T,
typename ElementSpaceSize>
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
ElementSpaceSize element_space_size)
{
return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence, long_index_t>{
p, element_space_size};

View File

@@ -113,7 +113,7 @@ struct ReferenceMoeGemm : public device::BaseOperator
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4);
#else
v_a = i4 - 8;
v_a = i4 - 8;
#endif
}
else
@@ -123,23 +123,25 @@ struct ReferenceMoeGemm : public device::BaseOperator
// same for B matrix
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data;
uint8_t i4 = 0;
uint8_t i4 = 0;
uint8_t i4_up = 0;
if(k % 2 == 1) {
i4 = (i4x2 >> 0) & 0xf;
if(k % 2 == 1)
{
i4 = (i4x2 >> 0) & 0xf;
i4_up = (i4x2_up >> 0) & 0xf;
}
else {
i4 = (i4x2 >> 4) & 0xf;
else
{
i4 = (i4x2 >> 4) & 0xf;
i4_up = (i4x2_up >> 4) & 0xf;
}
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4);
v_b = i4_to_f32_gfx9(i4);
v_b_up = i4_to_f32_gfx9(i4_up);
#else
v_b = i4 - 8;
v_b = i4 - 8;
v_b_up = i4_up - 8;
#endif
}