Skip to content

Commit a5b9f54

Browse files
committed
Auto merge of #3214 - eduardosm:move-x86-code, r=RalfJung
Move some x86 intrinsics code to helper functions in `shims::x86` To make them reusable for intrinsics of other x86 features. Splitted from #3192
2 parents 33fb35e + 44bf5fc commit a5b9f54

File tree

3 files changed

+304
-265
lines changed

3 files changed

+304
-265
lines changed

src/shims/x86/mod.rs

Lines changed: 285 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use rand::Rng as _;
2+
3+
use rustc_apfloat::{ieee::Single, Float as _};
14
use rustc_middle::{mir, ty};
25
use rustc_span::Symbol;
36
use rustc_target::abi::Size;
@@ -331,6 +334,210 @@ fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
331334
Ok(())
332335
}
333336

337+
#[derive(Copy, Clone)]
338+
enum FloatUnaryOp {
339+
/// sqrt(x)
340+
///
341+
/// <https://www.felixcloutier.com/x86/sqrtss>
342+
/// <https://www.felixcloutier.com/x86/sqrtps>
343+
Sqrt,
344+
/// Approximation of 1/x
345+
///
346+
/// <https://www.felixcloutier.com/x86/rcpss>
347+
/// <https://www.felixcloutier.com/x86/rcpps>
348+
Rcp,
349+
/// Approximation of 1/sqrt(x)
350+
///
351+
/// <https://www.felixcloutier.com/x86/rsqrtss>
352+
/// <https://www.felixcloutier.com/x86/rsqrtps>
353+
Rsqrt,
354+
}
355+
356+
/// Performs `which` scalar operation on `op` and returns the result.
357+
#[allow(clippy::arithmetic_side_effects)] // floating point operations without side effects
358+
fn unary_op_f32<'tcx>(
359+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
360+
which: FloatUnaryOp,
361+
op: &ImmTy<'tcx, Provenance>,
362+
) -> InterpResult<'tcx, Scalar<Provenance>> {
363+
match which {
364+
FloatUnaryOp::Sqrt => {
365+
let op = op.to_scalar();
366+
// FIXME using host floats
367+
Ok(Scalar::from_u32(f32::from_bits(op.to_u32()?).sqrt().to_bits()))
368+
}
369+
FloatUnaryOp::Rcp => {
370+
let op = op.to_scalar().to_f32()?;
371+
let div = (Single::from_u128(1).value / op).value;
372+
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
373+
// inaccuracy of RCP.
374+
let res = apply_random_float_error(this, div, -12);
375+
Ok(Scalar::from_f32(res))
376+
}
377+
FloatUnaryOp::Rsqrt => {
378+
let op = op.to_scalar().to_u32()?;
379+
// FIXME using host floats
380+
let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into());
381+
let rsqrt = (Single::from_u128(1).value / sqrt).value;
382+
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
383+
// inaccuracy of RSQRT.
384+
let res = apply_random_float_error(this, rsqrt, -12);
385+
Ok(Scalar::from_f32(res))
386+
}
387+
}
388+
}
389+
390+
/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
391+
#[allow(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
392+
fn apply_random_float_error<F: rustc_apfloat::Float>(
393+
this: &mut crate::MiriInterpCx<'_, '_>,
394+
val: F,
395+
err_scale: i32,
396+
) -> F {
397+
let rng = this.machine.rng.get_mut();
398+
// generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale
399+
let err =
400+
F::from_u128(rng.gen::<u64>().into()).value.scalbn(err_scale.checked_sub(64).unwrap());
401+
// give it a random sign
402+
let err = if rng.gen::<bool>() { -err } else { err };
403+
// multiple the value with (1+err)
404+
(val * (F::from_u128(1).value + err).value).value
405+
}
406+
407+
/// Performs `which` operation on the first component of `op` and copies
408+
/// the other components. The result is stored in `dest`.
409+
fn unary_op_ss<'tcx>(
410+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
411+
which: FloatUnaryOp,
412+
op: &OpTy<'tcx, Provenance>,
413+
dest: &PlaceTy<'tcx, Provenance>,
414+
) -> InterpResult<'tcx, ()> {
415+
let (op, op_len) = this.operand_to_simd(op)?;
416+
let (dest, dest_len) = this.place_to_simd(dest)?;
417+
418+
assert_eq!(dest_len, op_len);
419+
420+
let res0 = unary_op_f32(this, which, &this.read_immediate(&this.project_index(&op, 0)?)?)?;
421+
this.write_scalar(res0, &this.project_index(&dest, 0)?)?;
422+
423+
for i in 1..dest_len {
424+
this.copy_op(
425+
&this.project_index(&op, i)?,
426+
&this.project_index(&dest, i)?,
427+
/*allow_transmute*/ false,
428+
)?;
429+
}
430+
431+
Ok(())
432+
}
433+
434+
/// Performs `which` operation on each component of `op`, storing the
435+
/// result is stored in `dest`.
436+
fn unary_op_ps<'tcx>(
437+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
438+
which: FloatUnaryOp,
439+
op: &OpTy<'tcx, Provenance>,
440+
dest: &PlaceTy<'tcx, Provenance>,
441+
) -> InterpResult<'tcx, ()> {
442+
let (op, op_len) = this.operand_to_simd(op)?;
443+
let (dest, dest_len) = this.place_to_simd(dest)?;
444+
445+
assert_eq!(dest_len, op_len);
446+
447+
for i in 0..dest_len {
448+
let op = this.read_immediate(&this.project_index(&op, i)?)?;
449+
let dest = this.project_index(&dest, i)?;
450+
451+
let res = unary_op_f32(this, which, &op)?;
452+
this.write_scalar(res, &dest)?;
453+
}
454+
455+
Ok(())
456+
}
457+
458+
// Rounds the first element of `right` according to `rounding`
459+
// and copies the remaining elements from `left`.
460+
fn round_first<'tcx, F: rustc_apfloat::Float>(
461+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
462+
left: &OpTy<'tcx, Provenance>,
463+
right: &OpTy<'tcx, Provenance>,
464+
rounding: &OpTy<'tcx, Provenance>,
465+
dest: &PlaceTy<'tcx, Provenance>,
466+
) -> InterpResult<'tcx, ()> {
467+
let (left, left_len) = this.operand_to_simd(left)?;
468+
let (right, right_len) = this.operand_to_simd(right)?;
469+
let (dest, dest_len) = this.place_to_simd(dest)?;
470+
471+
assert_eq!(dest_len, left_len);
472+
assert_eq!(dest_len, right_len);
473+
474+
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
475+
476+
let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?;
477+
let res = op0.round_to_integral(rounding).value;
478+
this.write_scalar(
479+
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
480+
&this.project_index(&dest, 0)?,
481+
)?;
482+
483+
for i in 1..dest_len {
484+
this.copy_op(
485+
&this.project_index(&left, i)?,
486+
&this.project_index(&dest, i)?,
487+
/*allow_transmute*/ false,
488+
)?;
489+
}
490+
491+
Ok(())
492+
}
493+
494+
// Rounds all elements of `op` according to `rounding`.
495+
fn round_all<'tcx, F: rustc_apfloat::Float>(
496+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
497+
op: &OpTy<'tcx, Provenance>,
498+
rounding: &OpTy<'tcx, Provenance>,
499+
dest: &PlaceTy<'tcx, Provenance>,
500+
) -> InterpResult<'tcx, ()> {
501+
let (op, op_len) = this.operand_to_simd(op)?;
502+
let (dest, dest_len) = this.place_to_simd(dest)?;
503+
504+
assert_eq!(dest_len, op_len);
505+
506+
let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?;
507+
508+
for i in 0..dest_len {
509+
let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?;
510+
let res = op.round_to_integral(rounding).value;
511+
this.write_scalar(
512+
Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
513+
&this.project_index(&dest, i)?,
514+
)?;
515+
}
516+
517+
Ok(())
518+
}
519+
520+
/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
521+
/// `round.{ss,sd,ps,pd}` intrinsics.
522+
fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
523+
// The fourth bit of `rounding` only affects the SSE status
524+
// register, which cannot be accessed from Miri (or from Rust,
525+
// for that matter), so we can ignore it.
526+
match rounding & !0b1000 {
527+
// When the third bit is 0, the rounding mode is determined by the
528+
// first two bits.
529+
0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven),
530+
0b001 => Ok(rustc_apfloat::Round::TowardNegative),
531+
0b010 => Ok(rustc_apfloat::Round::TowardPositive),
532+
0b011 => Ok(rustc_apfloat::Round::TowardZero),
533+
// When the third bit is 1, the rounding mode is determined by the
534+
// SSE status register. Since we do not support modifying it from
535+
// Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
536+
0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven),
537+
rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"),
538+
}
539+
}
540+
334541
/// Converts each element of `op` from floating point to signed integer.
335542
///
336543
/// When the input value is NaN or out of range, fall back to minimum value.
@@ -408,3 +615,81 @@ fn horizontal_bin_op<'tcx>(
408615

409616
Ok(())
410617
}
618+
619+
/// Conditionally multiplies the packed floating-point elements in
620+
/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
621+
/// products (up to 4), and conditionally stores the sum in `dest` using
622+
/// the low 4 bits of `imm`.
623+
fn conditional_dot_product<'tcx>(
624+
this: &mut crate::MiriInterpCx<'_, 'tcx>,
625+
left: &OpTy<'tcx, Provenance>,
626+
right: &OpTy<'tcx, Provenance>,
627+
imm: &OpTy<'tcx, Provenance>,
628+
dest: &PlaceTy<'tcx, Provenance>,
629+
) -> InterpResult<'tcx, ()> {
630+
let (left, left_len) = this.operand_to_simd(left)?;
631+
let (right, right_len) = this.operand_to_simd(right)?;
632+
let (dest, dest_len) = this.place_to_simd(dest)?;
633+
634+
assert_eq!(left_len, right_len);
635+
assert!(dest_len <= 4);
636+
637+
let imm = this.read_scalar(imm)?.to_u8()?;
638+
639+
let element_layout = left.layout.field(this, 0);
640+
641+
// Calculate dot product
642+
// Elements are floating point numbers, but we can use `from_int`
643+
// because the representation of 0.0 is all zero bits.
644+
let mut sum = ImmTy::from_int(0u8, element_layout);
645+
for i in 0..left_len {
646+
if imm & (1 << i.checked_add(4).unwrap()) != 0 {
647+
let left = this.read_immediate(&this.project_index(&left, i)?)?;
648+
let right = this.read_immediate(&this.project_index(&right, i)?)?;
649+
650+
let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?;
651+
sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?;
652+
}
653+
}
654+
655+
// Write to destination (conditioned to imm)
656+
for i in 0..dest_len {
657+
let dest = this.project_index(&dest, i)?;
658+
659+
if imm & (1 << i) != 0 {
660+
this.write_immediate(*sum, &dest)?;
661+
} else {
662+
this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
663+
}
664+
}
665+
666+
Ok(())
667+
}
668+
669+
/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`.
670+
fn bin_op_folded<'tcx, T>(
671+
this: &crate::MiriInterpCx<'_, 'tcx>,
672+
lhs: &OpTy<'tcx, Provenance>,
673+
rhs: &OpTy<'tcx, Provenance>,
674+
init: T,
675+
mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>,
676+
) -> InterpResult<'tcx, T> {
677+
assert_eq!(lhs.layout, rhs.layout);
678+
679+
let (lhs, lhs_len) = this.operand_to_simd(lhs)?;
680+
let (rhs, rhs_len) = this.operand_to_simd(rhs)?;
681+
682+
assert_eq!(lhs_len, rhs_len);
683+
684+
let mut acc = init;
685+
for i in 0..lhs_len {
686+
let lhs = this.project_index(&lhs, i)?;
687+
let rhs = this.project_index(&rhs, i)?;
688+
689+
let lhs = this.read_immediate(&lhs)?;
690+
let rhs = this.read_immediate(&rhs)?;
691+
acc = f(acc, lhs, rhs)?;
692+
}
693+
694+
Ok(acc)
695+
}

0 commit comments

Comments
 (0)