diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 115cd2d71..282c382be 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1418,43 +1418,65 @@ where /// The windows are all distinct overlapping views of size `window_size` /// that fit into the array's shape. /// - /// This produces no elements if the window size is larger than the actual array size along any - /// axis. + /// This is essentially equivalent to [`.windows_with_stride()`] with unit stride. + pub fn windows(&self, window_size: E) -> Windows<'_, A, D> + where + E: IntoDimension, + S: Data, + { + Windows::new(self.view(), window_size) + } + + /// Return a window producer and iterable. + /// + /// The windows are all distinct views of size `window_size` + /// that fit into the array's shape. + /// + /// The stride is ordered by the outermost axis.
+ /// Hence, a (x₀, x₁, ..., xₙ) stride will be applied to + /// (A₀, A₁, ..., Aₙ) where Aₓ stands for `Axis(x)`. + /// + /// This produces all windows that fit within the array for the given stride, + /// assuming the window size is not larger than the array size. /// /// The produced element is an `ArrayView` with exactly the dimension /// `window_size`. + /// + /// Note that passing a stride of only ones is similar to + /// calling [`ArrayBase::windows()`]. /// - /// **Panics** if any dimension of `window_size` is zero.
- /// (**Panics** if `D` is `IxDyn` and `window_size` does not match the + /// **Panics** if any dimension of `window_size` or `stride` is zero.
+ /// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the /// number of array axes.) /// - /// This is an illustration of the 2×2 windows in a 3×4 array: + /// This is the same illustration found in [`ArrayBase::windows()`], + /// 2×2 windows in a 3×4 array, but now with a (1, 2) stride: /// /// ```text /// ──▶ Axis(1) /// - /// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┲━━━━━┳━━━━━┱─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓ - /// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ ┃ a₀₁ ┃ a₀₂ ┃ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃ - /// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ - /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ - /// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────╄━━━━━╇━━━━━╃─────┤ ├─────┼─────╄━━━━━╇━━━━━┩ - /// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ - /// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ - /// - /// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ - /// │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ - /// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────╆━━━━━╈━━━━━╅─────┤ ├─────┼─────╆━━━━━╈━━━━━┪ - /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ ┃ a₁₁ ┃ a₁₂ ┃ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ - /// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────╊━━━━━╋━━━━━╉─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ - /// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ ┃ a₂₁ ┃ a₂₂ ┃ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃ - /// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┺━━━━━┻━━━━━┹─────┘ └─────┴─────┺━━━━━┻━━━━━┛ + /// │ ┏━━━━━┳━━━━━┱─────┬─────┐ ┌─────┬─────┲━━━━━┳━━━━━┓ + /// ▼ ┃ a₀₀ ┃ a₀₁ ┃ │ │ │ │ ┃ a₀₂ ┃ a₀₃ ┃ + /// Axis(0) ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┡━━━━━╇━━━━━╃─────┼─────┤ ├─────┼─────╄━━━━━╇━━━━━┩ + /// │ │ │ │ │ │ │ │ │ │ + /// └─────┴─────┴─────┴─────┘ └─────┴─────┴─────┴─────┘ + /// + /// ┌─────┬─────┬─────┬─────┐ ┌─────┬─────┬─────┬─────┐ + /// │ │ │ │ │ │ │ │ │ │ + /// ┢━━━━━╈━━━━━╅─────┼─────┤ ├─────┼─────╆━━━━━╈━━━━━┪ + /// ┃ a₁₀ ┃ a₁₁ ┃ │ │ │ │ ┃ a₁₂ ┃ a₁₃ ┃ + /// ┣━━━━━╋━━━━━╉─────┼─────┤ ├─────┼─────╊━━━━━╋━━━━━┫ + /// ┃ a₂₀ ┃ a₂₁ ┃ │ │ │ │ ┃ a₂₂ ┃ a₂₃ ┃ + /// ┗━━━━━┻━━━━━┹─────┴─────┘ └─────┴─────┺━━━━━┻━━━━━┛ /// ``` - pub fn windows(&self, window_size: E) -> Windows<'_, A, D> + pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> where E: IntoDimension, S: Data, { - Windows::new(self.view(), window_size) + Windows::new_with_stride(self.view(), window_size, stride) } /// Returns a producer which traverses over all windows of a given length along an axis. diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index c47bfecec..d84b0e7b8 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -20,6 +20,20 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { E: IntoDimension, { let window = window_size.into_dimension(); + let ndim = window.ndim(); + + let mut unit_stride = D::zeros(ndim); + unit_stride.slice_mut().fill(1); + + Windows::new_with_stride(a, window, unit_stride) + } + + pub(crate) fn new_with_stride(a: ArrayView<'a, A, D>, window_size: E, strides: E) -> Self + where + E: IntoDimension, + { + let window = window_size.into_dimension(); + let strides_d = strides.into_dimension(); ndassert!( a.ndim() == window.ndim(), concat!( @@ -30,18 +44,42 @@ impl<'a, A, D: Dimension> Windows<'a, A, D> { a.ndim(), a.shape() ); + ndassert!( + a.ndim() == strides_d.ndim(), + concat!( + "Stride dimension {} does not match array dimension {} ", + "(with array of shape {:?})" + ), + strides_d.ndim(), + a.ndim(), + a.shape() + ); let mut size = a.dim; - for (sz, &ws) in size.slice_mut().iter_mut().zip(window.slice()) { + for ((sz, &ws), &stride) in size + .slice_mut() + .iter_mut() + .zip(window.slice()) + .zip(strides_d.slice()) + { assert_ne!(ws, 0, "window-size must not be zero!"); + assert_ne!(stride, 0, "stride cannot have a dimension as zero!"); // cannot use std::cmp::max(0, ..) since arithmetic underflow panics - *sz = if *sz < ws { 0 } else { *sz - ws + 1 }; + *sz = if *sz < ws { + 0 + } else { + ((*sz - (ws - 1) - 1) / stride) + 1 + }; } - let window_strides = a.strides.clone(); + let mut array_strides = a.strides.clone(); + for (arr_stride, ix_stride) in array_strides.slice_mut().iter_mut().zip(strides_d.slice()) { + *arr_stride *= ix_stride; + } + unsafe { Windows { - base: ArrayView::new(a.ptr, size, a.strides), + base: ArrayView::new(a.ptr, size, array_strides), window, strides: window_strides, } diff --git a/tests/windows.rs b/tests/windows.rs index 432be5e41..2c928aaef 100644 --- a/tests/windows.rs +++ b/tests/windows.rs @@ -30,7 +30,7 @@ fn windows_iterator_zero_size() { a.windows(Dim((0, 0, 0))); } -/// Test that verifites that no windows are yielded on oversized window sizes. +/// Test that verifies that no windows are yielded on oversized window sizes. #[test] fn windows_iterator_oversized() { let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); @@ -95,6 +95,76 @@ fn windows_iterator_3d() { ); } +/// Test that verifies the `Windows` iterator panics when stride has an axis equal to zero. +#[test] +#[should_panic] +fn windows_iterator_stride_axis_zero() { + let a = Array::from_iter(10..37).into_shape((3, 3, 3)).unwrap(); + a.windows_with_stride((2, 2, 2), (0, 2, 2)); +} + +/// Test that verifies that only first window is yielded when stride is oversized on every axis. +#[test] +fn windows_iterator_only_one_valid_window_for_oversized_stride() { + let a = Array::from_iter(10..135).into_shape((5, 5, 5)).unwrap(); + let mut iter = a.windows_with_stride((2, 2, 2), (8, 8, 8)).into_iter(); // (4,3,2) doesn't fit into (3,3,3) => oversized! + itertools::assert_equal( + iter.next(), + Some(arr3(&[[[10, 11], [15, 16]], [[35, 36], [40, 41]]])), + ); +} + +/// Simple test for iterating 1d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_1d_with_stride() { + let a = Array::from_iter(10..20).into_shape(10).unwrap(); + itertools::assert_equal( + a.windows_with_stride(4, 2), + vec![ + arr1(&[10, 11, 12, 13]), + arr1(&[12, 13, 14, 15]), + arr1(&[14, 15, 16, 17]), + arr1(&[16, 17, 18, 19]), + ], + ); +} + +/// Simple test for iterating 2d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_2d_with_stride() { + let a = Array::from_iter(10..30).into_shape((5, 4)).unwrap(); + itertools::assert_equal( + a.windows_with_stride((3, 2), (2, 1)), + vec![ + arr2(&[[10, 11], [14, 15], [18, 19]]), + arr2(&[[11, 12], [15, 16], [19, 20]]), + arr2(&[[12, 13], [16, 17], [20, 21]]), + arr2(&[[18, 19], [22, 23], [26, 27]]), + arr2(&[[19, 20], [23, 24], [27, 28]]), + arr2(&[[20, 21], [24, 25], [28, 29]]), + ], + ); +} + +/// Simple test for iterating 3d-arrays via `Windows` with stride. +#[test] +fn windows_iterator_3d_with_stride() { + let a = Array::from_iter(10..74).into_shape((4, 4, 4)).unwrap(); + itertools::assert_equal( + a.windows_with_stride((2, 2, 2), (2, 2, 2)), + vec![ + arr3(&[[[10, 11], [14, 15]], [[26, 27], [30, 31]]]), + arr3(&[[[12, 13], [16, 17]], [[28, 29], [32, 33]]]), + arr3(&[[[18, 19], [22, 23]], [[34, 35], [38, 39]]]), + arr3(&[[[20, 21], [24, 25]], [[36, 37], [40, 41]]]), + arr3(&[[[42, 43], [46, 47]], [[58, 59], [62, 63]]]), + arr3(&[[[44, 45], [48, 49]], [[60, 61], [64, 65]]]), + arr3(&[[[50, 51], [54, 55]], [[66, 67], [70, 71]]]), + arr3(&[[[52, 53], [56, 57]], [[68, 69], [72, 73]]]), + ], + ); +} + #[test] fn test_window_zip() { let a = Array::from_iter(0..64).into_shape((4, 4, 4)).unwrap();