Skip to content

Commit 3986824

Browse files
committed
Adds triu and tril methods that mimic NumPy.
Includes branched implementations for f- and c-order arrays.
1 parent e734ce8 commit 3986824

File tree

2 files changed

+291
-0
lines changed

2 files changed

+291
-0
lines changed

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1616,3 +1616,6 @@ pub(crate) fn is_aligned<T>(ptr: *const T) -> bool
16161616
{
16171617
(ptr as usize) % ::std::mem::align_of::<T>() == 0
16181618
}
1619+
1620+
// Triangular constructors
1621+
mod tri;

src/tri.rs

+288
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
// Copyright 2014-2024 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use core::cmp::{max, min};
10+
11+
use num_traits::Zero;
12+
13+
use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip};
14+
15+
impl<S, A, D> ArrayBase<S, D>
16+
where
17+
S: Data<Elem = A>,
18+
D: Dimension,
19+
A: Clone + Zero,
20+
D::Smaller: Copy,
21+
{
22+
/// Upper triangular of an array.
23+
///
24+
/// Return a copy of the array with elements below the *k*-th diagonal zeroed.
25+
/// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
26+
/// For 0D and 1D arrays, `triu` will return an unchanged clone.
27+
///
28+
/// See also [`ArrayBase::tril`]
29+
///
30+
/// ```
31+
/// use ndarray::array;
32+
///
33+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
34+
/// let res = arr.triu(0);
35+
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
36+
/// ```
37+
pub fn triu(&self, k: isize) -> Array<A, D>
38+
{
39+
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
40+
true => {
41+
let n = self.ndim();
42+
let mut x = self.view();
43+
x.swap_axes(n - 2, n - 1);
44+
let mut tril = x.tril(-k);
45+
tril.swap_axes(n - 2, n - 1);
46+
47+
tril
48+
}
49+
false => {
50+
let mut res = Array::zeros(self.raw_dim());
51+
Zip::indexed(self.rows())
52+
.and(res.rows_mut())
53+
.for_each(|i, src, mut dst| {
54+
let row_num = i.into_dimension().last_elem();
55+
let lower = max(row_num as isize + k, 0);
56+
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
57+
});
58+
59+
res
60+
}
61+
}
62+
}
63+
64+
/// Lower triangular of an array.
65+
///
66+
/// Return a copy of the array with elements above the *k*-th diagonal zeroed.
67+
/// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
68+
/// For 0D and 1D arrays, `tril` will return an unchanged clone.
69+
///
70+
/// See also [`ArrayBase::triu`]
71+
///
72+
/// ```
73+
/// use ndarray::array;
74+
///
75+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
76+
/// let res = arr.tril(0);
77+
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
78+
/// ```
79+
pub fn tril(&self, k: isize) -> Array<A, D>
80+
{
81+
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
82+
true => {
83+
let n = self.ndim();
84+
let mut x = self.view();
85+
x.swap_axes(n - 2, n - 1);
86+
let mut tril = x.triu(-k);
87+
tril.swap_axes(n - 2, n - 1);
88+
89+
tril
90+
}
91+
false => {
92+
let mut res = Array::zeros(self.raw_dim());
93+
Zip::indexed(self.rows())
94+
.and(res.rows_mut())
95+
.for_each(|i, src, mut dst| {
96+
// This ncols must go inside the loop to avoid panic on 1D arrays.
97+
// Statistically-neglible difference in performance vs defining ncols at top.
98+
let ncols = src.len_of(Axis(src.ndim() - 1)) as isize;
99+
let row_num = i.into_dimension().last_elem();
100+
let upper = min(row_num as isize + k, ncols) + 1;
101+
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
102+
});
103+
104+
res
105+
}
106+
}
107+
}
108+
}
109+
110+
#[cfg(test)]
111+
mod tests
112+
{
113+
use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
114+
use std::vec;
115+
116+
#[test]
117+
fn test_keep_order()
118+
{
119+
let x = Array2::<f64>::ones((3, 3).f());
120+
let res = x.triu(0);
121+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
122+
123+
let res = x.tril(0);
124+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
125+
}
126+
127+
#[test]
128+
fn test_0d()
129+
{
130+
let x = Array0::<f64>::ones(());
131+
let res = x.triu(0);
132+
assert_eq!(res, x);
133+
134+
let res = x.tril(0);
135+
assert_eq!(res, x);
136+
137+
let x = Array0::<f64>::ones(().f());
138+
let res = x.triu(0);
139+
assert_eq!(res, x);
140+
141+
let res = x.tril(0);
142+
assert_eq!(res, x);
143+
}
144+
145+
#[test]
146+
fn test_1d()
147+
{
148+
let x = array![1, 2, 3];
149+
let res = x.triu(0);
150+
assert_eq!(res, x);
151+
152+
let res = x.triu(0);
153+
assert_eq!(res, x);
154+
155+
let x = Array1::<f64>::ones(3.f());
156+
let res = x.triu(0);
157+
assert_eq!(res, x);
158+
159+
let res = x.triu(0);
160+
assert_eq!(res, x);
161+
}
162+
163+
#[test]
164+
fn test_2d()
165+
{
166+
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
167+
168+
// Upper
169+
let res = x.triu(0);
170+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
171+
172+
// Lower
173+
let res = x.tril(0);
174+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
175+
176+
let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
177+
178+
// Upper
179+
let res = x.triu(0);
180+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
181+
182+
// Lower
183+
let res = x.tril(0);
184+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
185+
}
186+
187+
#[test]
188+
fn test_3d()
189+
{
190+
let x = array![
191+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
192+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
193+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
194+
];
195+
196+
// Upper
197+
let res = x.triu(0);
198+
assert_eq!(
199+
res,
200+
array![
201+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
202+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
203+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
204+
]
205+
);
206+
207+
// Lower
208+
let res = x.tril(0);
209+
assert_eq!(
210+
res,
211+
array![
212+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
213+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
214+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
215+
]
216+
);
217+
218+
let x = Array3::from_shape_vec(
219+
(3, 3, 3).f(),
220+
vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
221+
)
222+
.unwrap();
223+
224+
// Upper
225+
let res = x.triu(0);
226+
assert_eq!(
227+
res,
228+
array![
229+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
230+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
231+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
232+
]
233+
);
234+
235+
// Lower
236+
let res = x.tril(0);
237+
assert_eq!(
238+
res,
239+
array![
240+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
241+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
242+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
243+
]
244+
);
245+
}
246+
247+
#[test]
248+
fn test_off_axis()
249+
{
250+
let x = array![
251+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
252+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
253+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
254+
];
255+
256+
let res = x.triu(1);
257+
assert_eq!(
258+
res,
259+
array![
260+
[[0, 2, 3], [0, 0, 6], [0, 0, 0]],
261+
[[0, 11, 12], [0, 0, 15], [0, 0, 0]],
262+
[[0, 20, 21], [0, 0, 24], [0, 0, 0]]
263+
]
264+
);
265+
266+
let res = x.triu(-1);
267+
assert_eq!(
268+
res,
269+
array![
270+
[[1, 2, 3], [4, 5, 6], [0, 8, 9]],
271+
[[10, 11, 12], [13, 14, 15], [0, 17, 18]],
272+
[[19, 20, 21], [22, 23, 24], [0, 26, 27]]
273+
]
274+
);
275+
}
276+
277+
#[test]
278+
fn test_odd_shape()
279+
{
280+
let x = array![[1, 2, 3], [4, 5, 6]];
281+
let res = x.triu(0);
282+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);
283+
284+
let x = array![[1, 2], [3, 4], [5, 6]];
285+
let res = x.triu(0);
286+
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
287+
}
288+
}

0 commit comments

Comments
 (0)