Skip to content

Commit 44bf5fc

Browse files
committed
Move implementation of SSE4.1 dpps/dppd to helper function
1 parent b1fcba4 commit 44bf5fc

File tree

2 files changed

+52
-37
lines changed

2 files changed

+52
-37
lines changed

src/shims/x86/mod.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>(
616616
Ok(())
617617
}
618618

619+
/// Conditionally multiplies the packed floating-point elements in
620+
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
621+
/// products (up to 4), and conditionally stores the sum in `dest` using
622+
/// the low 4 bits of `imm`.
623+
fn conditional_dot_product<'tcx>(
624+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
625+
left: &OpTy<'tcx, Provenance>,
626+
right: &OpTy<'tcx, Provenance>,
627+
imm: &OpTy<'tcx, Provenance>,
628+
dest: &PlaceTy<'tcx, Provenance>,
629+
) -> InterpResult<'tcx, ()> {
630+
let (left, left_len) = this.operand_to_simd(left)?;
631+
let (right, right_len) = this.operand_to_simd(right)?;
632+
let (dest, dest_len) = this.place_to_simd(dest)?;
633+
634+
assert_eq!(left_len, right_len);
635+
assert!(dest_len <= 4);
636+
637+
let imm = this.read_scalar(imm)?.to_u8()?;
638+
639+
let element_layout = left.layout.field(this, 0);
640+
641+
// Calculate dot product
642+
// Elements are floating point numbers, but we can use `from_int`
643+
// because the representation of 0.0 is all zero bits.
644+
let mut sum = ImmTy::from_int(0u8, element_layout);
645+
for i in 0..left_len {
646+
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
647+
let left = this.read_immediate(&this.project_index(&left, i)?)?;
648+
let right = this.read_immediate(&this.project_index(&right, i)?)?;
649+
650+
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
651+
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
652+
}
653+
}
654+
655+
// Write to destination (conditioned to imm)
656+
for i in 0..dest_len {
657+
let dest = this.project_index(&dest, i)?;
658+
659+
if imm & (1 << i) != 0 {
660+
this.write_immediate(*sum, &dest)?;
661+
} else {
662+
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
663+
}
664+
}
665+
666+
Ok(())
667+
}
668+
619669
/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
620670
fn bin_op_folded<'tcx, T>(
621671
this: &crate::MiriInterpCx<'_, 'tcx>,

src/shims/x86/sse41.rs

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use rustc_middle::mir;
21
use rustc_span::Symbol;
32
use rustc_target::spec::abi::Abi;
43

5-
use super::{bin_op_folded, round_all, round_first};
4+
use super::{bin_op_folded, conditional_dot_product, round_all, round_first};
65
use crate::*;
76
use shims::foreign_items::EmulateForeignItemResult;
87

@@ -104,41 +103,7 @@ pub(super) trait EvalContextExt<'mir, 'tcx: 'mir>:
104103
let [left, right, imm] =
105104
this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
106105

107-
let (left, left_len) = this.operand_to_simd(left)?;
108-
let (right, right_len) = this.operand_to_simd(right)?;
109-
let (dest, dest_len) = this.place_to_simd(dest)?;
110-
111-
assert_eq!(left_len, right_len);
112-
assert!(dest_len <= 4);
113-
114-
let imm = this.read_scalar(imm)?.to_u8()?;
115-
116-
let element_layout = left.layout.field(this, 0);
117-
118-
// Calculate dot product
119-
// Elements are floating point numbers, but we can use `from_int`
120-
// because the representation of 0.0 is all zero bits.
121-
let mut sum = ImmTy::from_int(0u8, element_layout);
122-
for i in 0..left_len {
123-
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
124-
let left = this.read_immediate(&this.project_index(&left, i)?)?;
125-
let right = this.read_immediate(&this.project_index(&right, i)?)?;
126-
127-
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
128-
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
129-
}
130-
}
131-
132-
// Write to destination (conditioned to imm)
133-
for i in 0..dest_len {
134-
let dest = this.project_index(&dest, i)?;
135-
136-
if imm & (1 << i) != 0 {
137-
this.write_immediate(*sum, &dest)?;
138-
} else {
139-
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
140-
}
141-
}
106+
conditional_dot_product(this, left, right, imm, dest)?;
142107
}
143108
// Used to implement the _mm_floor_ss, _mm_ceil_ss and _mm_round_ss
144109
// functions. Rounds the first element of `right` according to `rounding`

0 commit comments

Comments
 (0)