Skip to content

Commit 7d24c54

Browse files
committed
Adds triu and tril methods that mimic NumPy.
Includes branched implementations for f- and c-order arrays.
1 parent e734ce8 commit 7d24c54

File tree

2 files changed

+283
-0
lines changed

2 files changed

+283
-0
lines changed

src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1616,3 +1616,6 @@ pub(crate) fn is_aligned<T>(ptr: *const T) -> bool
16161616
{
16171617
(ptr as usize) % ::std::mem::align_of::<T>() == 0
16181618
}
1619+
1620+
// Triangular constructors
1621+
mod tri;

src/tri.rs

+280
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
// Copyright 2014-2024 bluss and ndarray developers.
2+
//
3+
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4+
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5+
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6+
// option. This file may not be copied, modified, or distributed
7+
// except according to those terms.
8+
9+
use core::cmp::{max, min};
10+
11+
use num_traits::Zero;
12+
13+
use crate::{
14+
dimension::is_layout_f,
15+
Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip,
16+
};
17+
18+
impl<S, A, D> ArrayBase<S, D>
19+
where
20+
S: Data<Elem = A>,
21+
D: Dimension,
22+
A: Clone + Zero,
23+
D::Smaller: Copy,
24+
{
25+
/// Upper triangular of an array.
26+
///
27+
/// Return a copy of the array with elements below the *k*-th diagonal zeroed.
28+
/// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
29+
/// For 0D and 1D arrays, `triu` will return an unchanged clone.
30+
///
31+
/// See also [`ArrayBase::tril`]
32+
///
33+
/// ```
34+
/// use ndarray::array;
35+
///
36+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
37+
/// let res = arr.triu(0);
38+
/// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
39+
/// ```
40+
pub fn triu(&self, k: isize) -> Array<A, D> {
41+
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
42+
true => {
43+
let n = self.ndim();
44+
let mut x = self.view();
45+
x.swap_axes(n - 2, n - 1);
46+
let mut tril = x.tril(-k);
47+
tril.swap_axes(n - 2, n - 1);
48+
49+
tril
50+
},
51+
false => {
52+
let mut res = Array::zeros(self.raw_dim());
53+
Zip::indexed(self.rows())
54+
.and(res.rows_mut())
55+
.for_each(|i, src, mut dst| {
56+
let row_num = i.into_dimension().last_elem();
57+
let lower = max(row_num as isize + k, 0);
58+
dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..]));
59+
});
60+
61+
res
62+
}
63+
}
64+
}
65+
66+
/// Lower triangular of an array.
67+
///
68+
/// Return a copy of the array with elements above the *k*-th diagonal zeroed.
69+
/// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
70+
/// For 0D and 1D arrays, `tril` will return an unchanged clone.
71+
///
72+
/// See also [`ArrayBase::triu`]
73+
///
74+
/// ```
75+
/// use ndarray::array;
76+
///
77+
/// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
78+
/// let res = arr.tril(0);
79+
/// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
80+
/// ```
81+
pub fn tril(&self, k: isize) -> Array<A, D> {
82+
match self.ndim() > 1 && is_layout_f(&self.dim, &self.strides) {
83+
true => {
84+
let n = self.ndim();
85+
let mut x = self.view();
86+
x.swap_axes(n - 2, n - 1);
87+
let mut tril = x.triu(-k);
88+
tril.swap_axes(n - 2, n - 1);
89+
90+
tril
91+
},
92+
false => {
93+
let mut res = Array::zeros(self.raw_dim());
94+
Zip::indexed(self.rows())
95+
.and(res.rows_mut())
96+
.for_each(|i, src, mut dst| {
97+
// This ncols must go inside the loop to avoid panic on 1D arrays.
98+
// Statistically-neglible difference in performance vs defining ncols at top.
99+
let ncols = src.len_of(Axis(src.ndim() - 1)) as isize;
100+
let row_num = i.into_dimension().last_elem();
101+
let upper = min(row_num as isize + k, ncols) + 1;
102+
dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper]));
103+
});
104+
105+
res
106+
}
107+
}
108+
}
109+
}
110+
111+
#[cfg(test)]
112+
mod tests {
113+
use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
114+
115+
#[test]
116+
fn test_keep_order() {
117+
let x = Array2::<f64>::ones((3, 3).f());
118+
let res = x.triu(0);
119+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
120+
121+
let res = x.tril(0);
122+
assert!(dimension::is_layout_f(&res.dim, &res.strides));
123+
}
124+
125+
#[test]
126+
fn test_0d() {
127+
let x = Array0::<f64>::ones(());
128+
let res = x.triu(0);
129+
assert_eq!(res, x);
130+
131+
let res = x.tril(0);
132+
assert_eq!(res, x);
133+
134+
let x = Array0::<f64>::ones(().f());
135+
let res = x.triu(0);
136+
assert_eq!(res, x);
137+
138+
let res = x.tril(0);
139+
assert_eq!(res, x);
140+
}
141+
142+
#[test]
143+
fn test_1d() {
144+
let x = array![1, 2, 3];
145+
let res = x.triu(0);
146+
assert_eq!(res, x);
147+
148+
let res = x.triu(0);
149+
assert_eq!(res, x);
150+
151+
let x = Array1::<f64>::ones(3.f());
152+
let res = x.triu(0);
153+
assert_eq!(res, x);
154+
155+
let res = x.triu(0);
156+
assert_eq!(res, x);
157+
}
158+
159+
#[test]
160+
fn test_2d() {
161+
let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
162+
163+
// Upper
164+
let res = x.triu(0);
165+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
166+
167+
// Lower
168+
let res = x.tril(0);
169+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
170+
171+
let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
172+
173+
// Upper
174+
let res = x.triu(0);
175+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
176+
177+
// Lower
178+
let res = x.tril(0);
179+
assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
180+
}
181+
182+
#[test]
183+
fn test_3d() {
184+
let x = array![
185+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
186+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
187+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
188+
];
189+
190+
// Upper
191+
let res = x.triu(0);
192+
assert_eq!(
193+
res,
194+
array![
195+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
196+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
197+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
198+
]
199+
);
200+
201+
// Lower
202+
let res = x.tril(0);
203+
assert_eq!(
204+
res,
205+
array![
206+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
207+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
208+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
209+
]
210+
);
211+
212+
let x = Array3::from_shape_vec(
213+
(3, 3, 3).f(),
214+
vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
215+
)
216+
.unwrap();
217+
218+
// Upper
219+
let res = x.triu(0);
220+
assert_eq!(
221+
res,
222+
array![
223+
[[1, 2, 3], [0, 5, 6], [0, 0, 9]],
224+
[[10, 11, 12], [0, 14, 15], [0, 0, 18]],
225+
[[19, 20, 21], [0, 23, 24], [0, 0, 27]]
226+
]
227+
);
228+
229+
// Lower
230+
let res = x.tril(0);
231+
assert_eq!(
232+
res,
233+
array![
234+
[[1, 0, 0], [4, 5, 0], [7, 8, 9]],
235+
[[10, 0, 0], [13, 14, 0], [16, 17, 18]],
236+
[[19, 0, 0], [22, 23, 0], [25, 26, 27]]
237+
]
238+
);
239+
}
240+
241+
#[test]
242+
fn test_off_axis() {
243+
let x = array![
244+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
245+
[[10, 11, 12], [13, 14, 15], [16, 17, 18]],
246+
[[19, 20, 21], [22, 23, 24], [25, 26, 27]]
247+
];
248+
249+
let res = x.triu(1);
250+
assert_eq!(
251+
res,
252+
array![
253+
[[0, 2, 3], [0, 0, 6], [0, 0, 0]],
254+
[[0, 11, 12], [0, 0, 15], [0, 0, 0]],
255+
[[0, 20, 21], [0, 0, 24], [0, 0, 0]]
256+
]
257+
);
258+
259+
let res = x.triu(-1);
260+
assert_eq!(
261+
res,
262+
array![
263+
[[1, 2, 3], [4, 5, 6], [0, 8, 9]],
264+
[[10, 11, 12], [13, 14, 15], [0, 17, 18]],
265+
[[19, 20, 21], [22, 23, 24], [0, 26, 27]]
266+
]
267+
);
268+
}
269+
270+
#[test]
271+
fn test_odd_shape() {
272+
let x = array![[1, 2, 3], [4, 5, 6]];
273+
let res = x.triu(0);
274+
assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);
275+
276+
let x = array![[1, 2], [3, 4], [5, 6]];
277+
let res = x.triu(0);
278+
assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
279+
}
280+
}

0 commit comments

Comments
 (0)