@@ -616,6 +616,56 @@ fn horizontal_bin_op<'tcx>(
616
616
Ok ( ( ) )
617
617
}
618
618
619
+ /// Conditionally multiplies the packed floating-point elements in
620
+ /// `left` and `right` using the high 4 bits in `imm`, sums the calculated
621
+ /// products (up to 4), and conditionally stores the sum in `dest` using
622
+ /// the low 4 bits of `imm`.
623
+ fn conditional_dot_product < ' tcx > (
624
+ this : & mut crate :: MiriInterpCx < ' _ , ' tcx > ,
625
+ left : & OpTy < ' tcx , Provenance > ,
626
+ right : & OpTy < ' tcx , Provenance > ,
627
+ imm : & OpTy < ' tcx , Provenance > ,
628
+ dest : & PlaceTy < ' tcx , Provenance > ,
629
+ ) -> InterpResult < ' tcx , ( ) > {
630
+ let ( left, left_len) = this. operand_to_simd ( left) ?;
631
+ let ( right, right_len) = this. operand_to_simd ( right) ?;
632
+ let ( dest, dest_len) = this. place_to_simd ( dest) ?;
633
+
634
+ assert_eq ! ( left_len, right_len) ;
635
+ assert ! ( dest_len <= 4 ) ;
636
+
637
+ let imm = this. read_scalar ( imm) ?. to_u8 ( ) ?;
638
+
639
+ let element_layout = left. layout . field ( this, 0 ) ;
640
+
641
+ // Calculate dot product
642
+ // Elements are floating point numbers, but we can use `from_int`
643
+ // because the representation of 0.0 is all zero bits.
644
+ let mut sum = ImmTy :: from_int ( 0u8 , element_layout) ;
645
+ for i in 0 ..left_len {
646
+ if imm & ( 1 << i. checked_add ( 4 ) . unwrap ( ) ) != 0 {
647
+ let left = this. read_immediate ( & this. project_index ( & left, i) ?) ?;
648
+ let right = this. read_immediate ( & this. project_index ( & right, i) ?) ?;
649
+
650
+ let mul = this. wrapping_binary_op ( mir:: BinOp :: Mul , & left, & right) ?;
651
+ sum = this. wrapping_binary_op ( mir:: BinOp :: Add , & sum, & mul) ?;
652
+ }
653
+ }
654
+
655
+ // Write to destination (conditioned to imm)
656
+ for i in 0 ..dest_len {
657
+ let dest = this. project_index ( & dest, i) ?;
658
+
659
+ if imm & ( 1 << i) != 0 {
660
+ this. write_immediate ( * sum, & dest) ?;
661
+ } else {
662
+ this. write_scalar ( Scalar :: from_int ( 0u8 , element_layout. size ) , & dest) ?;
663
+ }
664
+ }
665
+
666
+ Ok ( ( ) )
667
+ }
668
+
619
669
/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
620
670
fn bin_op_folded < ' tcx , T > (
621
671
this : & crate :: MiriInterpCx < ' _ , ' tcx > ,
0 commit comments