|
| 1 | +use rand::Rng as _; |
| 2 | + |
| 3 | +use rustc_apfloat::{ieee::Single, Float as _}; |
1 | 4 | use rustc_middle::{mir, ty};
|
2 | 5 | use rustc_span::Symbol;
|
3 | 6 | use rustc_target::abi::Size;
|
@@ -331,6 +334,210 @@ fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
|
331 | 334 | Ok(())
|
332 | 335 | }
|
333 | 336 |
|
| 337 | +#[derive(Copy, Clone)] |
| 338 | +enum FloatUnaryOp { |
| 339 | + /// sqrt(x) |
| 340 | + /// |
| 341 | + /// <https://www.felixcloutier.com/x86/sqrtss> |
| 342 | + /// <https://www.felixcloutier.com/x86/sqrtps> |
| 343 | + Sqrt, |
| 344 | + /// Approximation of 1/x |
| 345 | + /// |
| 346 | + /// <https://www.felixcloutier.com/x86/rcpss> |
| 347 | + /// <https://www.felixcloutier.com/x86/rcpps> |
| 348 | + Rcp, |
| 349 | + /// Approximation of 1/sqrt(x) |
| 350 | + /// |
| 351 | + /// <https://www.felixcloutier.com/x86/rsqrtss> |
| 352 | + /// <https://www.felixcloutier.com/x86/rsqrtps> |
| 353 | + Rsqrt, |
| 354 | +} |
| 355 | + |
| 356 | +/// Performs `which` scalar operation on `op` and returns the result. |
| 357 | +#[allow(clippy::arithmetic_side_effects)] // floating point operations without side effects |
| 358 | +fn unary_op_f32<'tcx>( |
| 359 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 360 | + which: FloatUnaryOp, |
| 361 | + op: &ImmTy<'tcx, Provenance>, |
| 362 | +) -> InterpResult<'tcx, Scalar<Provenance>> { |
| 363 | + match which { |
| 364 | + FloatUnaryOp::Sqrt => { |
| 365 | + let op = op.to_scalar(); |
| 366 | + // FIXME using host floats |
| 367 | + Ok(Scalar::from_u32(f32::from_bits(op.to_u32()?).sqrt().to_bits())) |
| 368 | + } |
| 369 | + FloatUnaryOp::Rcp => { |
| 370 | + let op = op.to_scalar().to_f32()?; |
| 371 | + let div = (Single::from_u128(1).value / op).value; |
| 372 | + // Apply a relative error with a magnitude on the order of 2^-12 to simulate the |
| 373 | + // inaccuracy of RCP. |
| 374 | + let res = apply_random_float_error(this, div, -12); |
| 375 | + Ok(Scalar::from_f32(res)) |
| 376 | + } |
| 377 | + FloatUnaryOp::Rsqrt => { |
| 378 | + let op = op.to_scalar().to_u32()?; |
| 379 | + // FIXME using host floats |
| 380 | + let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into()); |
| 381 | + let rsqrt = (Single::from_u128(1).value / sqrt).value; |
| 382 | + // Apply a relative error with a magnitude on the order of 2^-12 to simulate the |
| 383 | + // inaccuracy of RSQRT. |
| 384 | + let res = apply_random_float_error(this, rsqrt, -12); |
| 385 | + Ok(Scalar::from_f32(res)) |
| 386 | + } |
| 387 | + } |
| 388 | +} |
| 389 | + |
| 390 | +/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale). |
| 391 | +#[allow(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic |
| 392 | +fn apply_random_float_error<F: rustc_apfloat::Float>( |
| 393 | + this: &mut crate::MiriInterpCx<'_, '_>, |
| 394 | + val: F, |
| 395 | + err_scale: i32, |
| 396 | +) -> F { |
| 397 | + let rng = this.machine.rng.get_mut(); |
| 398 | + // generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale |
| 399 | + let err = |
| 400 | + F::from_u128(rng.gen::<u64>().into()).value.scalbn(err_scale.checked_sub(64).unwrap()); |
| 401 | + // give it a random sign |
| 402 | + let err = if rng.gen::<bool>() { -err } else { err }; |
| 403 | + // multiple the value with (1+err) |
| 404 | + (val * (F::from_u128(1).value + err).value).value |
| 405 | +} |
| 406 | + |
| 407 | +/// Performs `which` operation on the first component of `op` and copies |
| 408 | +/// the other components. The result is stored in `dest`. |
| 409 | +fn unary_op_ss<'tcx>( |
| 410 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 411 | + which: FloatUnaryOp, |
| 412 | + op: &OpTy<'tcx, Provenance>, |
| 413 | + dest: &PlaceTy<'tcx, Provenance>, |
| 414 | +) -> InterpResult<'tcx, ()> { |
| 415 | + let (op, op_len) = this.operand_to_simd(op)?; |
| 416 | + let (dest, dest_len) = this.place_to_simd(dest)?; |
| 417 | + |
| 418 | + assert_eq!(dest_len, op_len); |
| 419 | + |
| 420 | + let res0 = unary_op_f32(this, which, &this.read_immediate(&this.project_index(&op, 0)?)?)?; |
| 421 | + this.write_scalar(res0, &this.project_index(&dest, 0)?)?; |
| 422 | + |
| 423 | + for i in 1..dest_len { |
| 424 | + this.copy_op( |
| 425 | + &this.project_index(&op, i)?, |
| 426 | + &this.project_index(&dest, i)?, |
| 427 | + /*allow_transmute*/ false, |
| 428 | + )?; |
| 429 | + } |
| 430 | + |
| 431 | + Ok(()) |
| 432 | +} |
| 433 | + |
| 434 | +/// Performs `which` operation on each component of `op`, storing the |
| 435 | +/// result is stored in `dest`. |
| 436 | +fn unary_op_ps<'tcx>( |
| 437 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 438 | + which: FloatUnaryOp, |
| 439 | + op: &OpTy<'tcx, Provenance>, |
| 440 | + dest: &PlaceTy<'tcx, Provenance>, |
| 441 | +) -> InterpResult<'tcx, ()> { |
| 442 | + let (op, op_len) = this.operand_to_simd(op)?; |
| 443 | + let (dest, dest_len) = this.place_to_simd(dest)?; |
| 444 | + |
| 445 | + assert_eq!(dest_len, op_len); |
| 446 | + |
| 447 | + for i in 0..dest_len { |
| 448 | + let op = this.read_immediate(&this.project_index(&op, i)?)?; |
| 449 | + let dest = this.project_index(&dest, i)?; |
| 450 | + |
| 451 | + let res = unary_op_f32(this, which, &op)?; |
| 452 | + this.write_scalar(res, &dest)?; |
| 453 | + } |
| 454 | + |
| 455 | + Ok(()) |
| 456 | +} |
| 457 | + |
| 458 | +// Rounds the first element of `right` according to `rounding` |
| 459 | +// and copies the remaining elements from `left`. |
| 460 | +fn round_first<'tcx, F: rustc_apfloat::Float>( |
| 461 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 462 | + left: &OpTy<'tcx, Provenance>, |
| 463 | + right: &OpTy<'tcx, Provenance>, |
| 464 | + rounding: &OpTy<'tcx, Provenance>, |
| 465 | + dest: &PlaceTy<'tcx, Provenance>, |
| 466 | +) -> InterpResult<'tcx, ()> { |
| 467 | + let (left, left_len) = this.operand_to_simd(left)?; |
| 468 | + let (right, right_len) = this.operand_to_simd(right)?; |
| 469 | + let (dest, dest_len) = this.place_to_simd(dest)?; |
| 470 | + |
| 471 | + assert_eq!(dest_len, left_len); |
| 472 | + assert_eq!(dest_len, right_len); |
| 473 | + |
| 474 | + let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?; |
| 475 | + |
| 476 | + let op0: F = this.read_scalar(&this.project_index(&right, 0)?)?.to_float()?; |
| 477 | + let res = op0.round_to_integral(rounding).value; |
| 478 | + this.write_scalar( |
| 479 | + Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)), |
| 480 | + &this.project_index(&dest, 0)?, |
| 481 | + )?; |
| 482 | + |
| 483 | + for i in 1..dest_len { |
| 484 | + this.copy_op( |
| 485 | + &this.project_index(&left, i)?, |
| 486 | + &this.project_index(&dest, i)?, |
| 487 | + /*allow_transmute*/ false, |
| 488 | + )?; |
| 489 | + } |
| 490 | + |
| 491 | + Ok(()) |
| 492 | +} |
| 493 | + |
| 494 | +// Rounds all elements of `op` according to `rounding`. |
| 495 | +fn round_all<'tcx, F: rustc_apfloat::Float>( |
| 496 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 497 | + op: &OpTy<'tcx, Provenance>, |
| 498 | + rounding: &OpTy<'tcx, Provenance>, |
| 499 | + dest: &PlaceTy<'tcx, Provenance>, |
| 500 | +) -> InterpResult<'tcx, ()> { |
| 501 | + let (op, op_len) = this.operand_to_simd(op)?; |
| 502 | + let (dest, dest_len) = this.place_to_simd(dest)?; |
| 503 | + |
| 504 | + assert_eq!(dest_len, op_len); |
| 505 | + |
| 506 | + let rounding = rounding_from_imm(this.read_scalar(rounding)?.to_i32()?)?; |
| 507 | + |
| 508 | + for i in 0..dest_len { |
| 509 | + let op: F = this.read_scalar(&this.project_index(&op, i)?)?.to_float()?; |
| 510 | + let res = op.round_to_integral(rounding).value; |
| 511 | + this.write_scalar( |
| 512 | + Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)), |
| 513 | + &this.project_index(&dest, i)?, |
| 514 | + )?; |
| 515 | + } |
| 516 | + |
| 517 | + Ok(()) |
| 518 | +} |
| 519 | + |
| 520 | +/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of |
| 521 | +/// `round.{ss,sd,ps,pd}` intrinsics. |
| 522 | +fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> { |
| 523 | + // The fourth bit of `rounding` only affects the SSE status |
| 524 | + // register, which cannot be accessed from Miri (or from Rust, |
| 525 | + // for that matter), so we can ignore it. |
| 526 | + match rounding & !0b1000 { |
| 527 | + // When the third bit is 0, the rounding mode is determined by the |
| 528 | + // first two bits. |
| 529 | + 0b000 => Ok(rustc_apfloat::Round::NearestTiesToEven), |
| 530 | + 0b001 => Ok(rustc_apfloat::Round::TowardNegative), |
| 531 | + 0b010 => Ok(rustc_apfloat::Round::TowardPositive), |
| 532 | + 0b011 => Ok(rustc_apfloat::Round::TowardZero), |
| 533 | + // When the third bit is 1, the rounding mode is determined by the |
| 534 | + // SSE status register. Since we do not support modifying it from |
| 535 | + // Miri (or Rust), we assume it to be at its default mode (round-to-nearest). |
| 536 | + 0b100..=0b111 => Ok(rustc_apfloat::Round::NearestTiesToEven), |
| 537 | + rounding => throw_unsup_format!("unsupported rounding mode 0x{rounding:02x}"), |
| 538 | + } |
| 539 | +} |
| 540 | + |
334 | 541 | /// Converts each element of `op` from floating point to signed integer.
|
335 | 542 | ///
|
336 | 543 | /// When the input value is NaN or out of range, fall back to minimum value.
|
@@ -408,3 +615,81 @@ fn horizontal_bin_op<'tcx>(
|
408 | 615 |
|
409 | 616 | Ok(())
|
410 | 617 | }
|
| 618 | + |
| 619 | +/// Conditionally multiplies the packed floating-point elements in |
| 620 | +/// `left` and `right` using the high 4 bits in `imm`, sums the calculated |
| 621 | +/// products (up to 4), and conditionally stores the sum in `dest` using |
| 622 | +/// the low 4 bits of `imm`. |
| 623 | +fn conditional_dot_product<'tcx>( |
| 624 | + this: &mut crate::MiriInterpCx<'_, 'tcx>, |
| 625 | + left: &OpTy<'tcx, Provenance>, |
| 626 | + right: &OpTy<'tcx, Provenance>, |
| 627 | + imm: &OpTy<'tcx, Provenance>, |
| 628 | + dest: &PlaceTy<'tcx, Provenance>, |
| 629 | +) -> InterpResult<'tcx, ()> { |
| 630 | + let (left, left_len) = this.operand_to_simd(left)?; |
| 631 | + let (right, right_len) = this.operand_to_simd(right)?; |
| 632 | + let (dest, dest_len) = this.place_to_simd(dest)?; |
| 633 | + |
| 634 | + assert_eq!(left_len, right_len); |
| 635 | + assert!(dest_len <= 4); |
| 636 | + |
| 637 | + let imm = this.read_scalar(imm)?.to_u8()?; |
| 638 | + |
| 639 | + let element_layout = left.layout.field(this, 0); |
| 640 | + |
| 641 | + // Calculate dot product |
| 642 | + // Elements are floating point numbers, but we can use `from_int` |
| 643 | + // because the representation of 0.0 is all zero bits. |
| 644 | + let mut sum = ImmTy::from_int(0u8, element_layout); |
| 645 | + for i in 0..left_len { |
| 646 | + if imm & (1 << i.checked_add(4).unwrap()) != 0 { |
| 647 | + let left = this.read_immediate(&this.project_index(&left, i)?)?; |
| 648 | + let right = this.read_immediate(&this.project_index(&right, i)?)?; |
| 649 | + |
| 650 | + let mul = this.wrapping_binary_op(mir::BinOp::Mul, &left, &right)?; |
| 651 | + sum = this.wrapping_binary_op(mir::BinOp::Add, &sum, &mul)?; |
| 652 | + } |
| 653 | + } |
| 654 | + |
| 655 | + // Write to destination (conditioned to imm) |
| 656 | + for i in 0..dest_len { |
| 657 | + let dest = this.project_index(&dest, i)?; |
| 658 | + |
| 659 | + if imm & (1 << i) != 0 { |
| 660 | + this.write_immediate(*sum, &dest)?; |
| 661 | + } else { |
| 662 | + this.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?; |
| 663 | + } |
| 664 | + } |
| 665 | + |
| 666 | + Ok(()) |
| 667 | +} |
| 668 | + |
| 669 | +/// Folds SIMD vectors `lhs` and `rhs` into a value of type `T` using `f`. |
| 670 | +fn bin_op_folded<'tcx, T>( |
| 671 | + this: &crate::MiriInterpCx<'_, 'tcx>, |
| 672 | + lhs: &OpTy<'tcx, Provenance>, |
| 673 | + rhs: &OpTy<'tcx, Provenance>, |
| 674 | + init: T, |
| 675 | + mut f: impl FnMut(T, ImmTy<'tcx, Provenance>, ImmTy<'tcx, Provenance>) -> InterpResult<'tcx, T>, |
| 676 | +) -> InterpResult<'tcx, T> { |
| 677 | + assert_eq!(lhs.layout, rhs.layout); |
| 678 | + |
| 679 | + let (lhs, lhs_len) = this.operand_to_simd(lhs)?; |
| 680 | + let (rhs, rhs_len) = this.operand_to_simd(rhs)?; |
| 681 | + |
| 682 | + assert_eq!(lhs_len, rhs_len); |
| 683 | + |
| 684 | + let mut acc = init; |
| 685 | + for i in 0..lhs_len { |
| 686 | + let lhs = this.project_index(&lhs, i)?; |
| 687 | + let rhs = this.project_index(&rhs, i)?; |
| 688 | + |
| 689 | + let lhs = this.read_immediate(&lhs)?; |
| 690 | + let rhs = this.read_immediate(&rhs)?; |
| 691 | + acc = f(acc, lhs, rhs)?; |
| 692 | + } |
| 693 | + |
| 694 | + Ok(acc) |
| 695 | +} |
0 commit comments