Skip to content

Commit 01500d8

Browse files
authored
Add matrix_apply_linter() (#1869)
1 parent e81967e commit 01500d8

10 files changed

+306
-2
lines changed

DESCRIPTION

+1
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ Collate:
120120
'lintr-package.R'
121121
'literal_coercion_linter.R'
122122
'make_linter_from_regex.R'
123+
'matrix_apply_linter.R'
123124
'methods.R'
124125
'missing_argument_linter.R'
125126
'missing_package_linter.R'

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ export(lint_package)
8484
export(linters_with_defaults)
8585
export(linters_with_tags)
8686
export(literal_coercion_linter)
87+
export(matrix_apply_linter)
8788
export(missing_argument_linter)
8889
export(missing_package_linter)
8990
export(modify_defaults)

NEWS.md

+2
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107

108108
### New linters
109109

110+
* `matrix_apply_linter()` recommends use of dedicated `rowSums()`, `colSums()`, `colMeans()`, `rowMeans()` over `apply(., MARGIN, sum)` or `apply(., MARGIN, mean)`. The recommended alternative is much more efficient and more readable (#1869, @Bisaloo).
111+
110112
* `unnecessary_lambda_linter()`: detect unnecessary lambdas (anonymous functions), e.g.
111113
`lapply(x, function(xi) sum(xi))` can be `lapply(x, sum)` and `purrr::map(x, ~quantile(.x, 0.75, na.rm = TRUE))`
112114
can be `purrr::map(x, quantile, 0.75, na.rm = TRUE)`. Naming `probs = 0.75` can further improve readability (#1531, #1866, @MichaelChirico, @Bisaloo).

R/matrix_apply_linter.R

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#' Require usage of `colSums(x)` or `rowSums(x)` over `apply(x, ., sum)`
2+
#'
3+
#' [colSums()] and [rowSums()] are clearer and more performant alternatives to
4+
#' `apply(x, 2, sum)` and `apply(x, 1, sum)` respectively in the case of 2D
5+
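#' arrays, or matrices. The same holds for [colMeans()] and [rowMeans()] over
#' `apply(x, 2, mean)` and `apply(x, 1, mean)`.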
6+
#'
7+
#' @examples
8+
#' # will produce lints
9+
#' lint(
10+
#' text = "apply(x, 1, sum)",
11+
#' linters = matrix_apply_linter()
12+
#' )
13+
#'
14+
#' lint(
15+
#' text = "apply(x, 2, sum)",
16+
#' linters = matrix_apply_linter()
17+
#' )
18+
#'
19+
#' lint(
20+
#' text = "apply(x, 2, sum, na.rm = TRUE)",
21+
#' linters = matrix_apply_linter()
22+
#' )
23+
#'
24+
#' lint(
25+
#' text = "apply(x, 2:4, sum)",
26+
#' linters = matrix_apply_linter()
27+
#' )
28+
#'
29+
#' @evalRd rd_tags("matrix_apply_linter")
30+
#' @seealso [linters] for a complete list of linters available in lintr.
31+
#' @export
32+
matrix_apply_linter <- function() {
  # mean() and sum() have very different signatures so we treat them separately.
  # sum() takes values to sum over via ..., has just one extra argument and is not a generic
  # mean() is a generic, takes values to average via a single argument, and can have extra arguments
  #
  # Currently supported values for MARGIN: scalar numeric and vector of contiguous values created by : (OP-COLON)
  #
  # NB: in XPath, 'and' binds more tightly than 'or'. The two accepted forms of MARGIN are
  #   therefore wrapped in parentheses so that the position() = 2 requirement applies to both;
  #   otherwise a bare NUM_CONST would be accepted at any argument position
  #   (e.g. the piped form x %>% apply(1, sum), which has no 4th expr to read FUN from).
  sums_xpath <- "
  //SYMBOL_FUNCTION_CALL[text() = 'apply']
    /parent::expr
    /following-sibling::expr[
      (NUM_CONST or OP-COLON/preceding-sibling::expr[NUM_CONST]/following-sibling::expr[NUM_CONST])
      and (position() = 2)
    ]
    /following-sibling::expr[
      SYMBOL[text() = 'sum']
      and (position() = 1)
    ]
    /parent::expr
  "

  # Since mean() is a generic, we make sure that we only lint cases with arguments
  # supported by colMeans() and rowMeans(), i.e., na.rm
  means_xpath <- "
  //SYMBOL_FUNCTION_CALL[text() = 'apply']
    /parent::expr
    /following-sibling::expr[
      (NUM_CONST or OP-COLON/preceding-sibling::expr[NUM_CONST]/following-sibling::expr[NUM_CONST])
      and (position() = 2)
    ]
    /following-sibling::expr[
      SYMBOL[text() = 'mean']
      and (position() = 1)
    ]
    /parent::expr[
      count(expr) = 4
      or (count(expr) = 5 and SYMBOL_SUB[text() = 'na.rm'])
    ]
  "

  xpath <- glue::glue("{sums_xpath} | {means_xpath}")

  # This doesn't handle the case when MARGIN and FUN are named and in a different position
  # but this should be relatively rare
  # (expr[1] is the function name itself, so the arguments start at expr[2])
  var_xpath <- "expr[position() = 2]"
  margin_xpath <- "expr[position() = 3]"
  fun_xpath <- "expr[position() = 4]"

  Linter(function(source_expression) {
    if (!is_lint_level(source_expression, "expression")) {
      return(list())
    }
    xml <- source_expression$xml_parsed_content

    bad_expr <- xml2::xml_find_all(xml, xpath)

    var <- xml2::xml_text(xml2::xml_find_all(bad_expr, var_xpath))

    # 'sum' / 'mean' -> 'Sum' / 'Mean', to be spliced into row{fun}s / col{fun}s
    fun <- xml2::xml_text(xml2::xml_find_all(bad_expr, fun_xpath))
    fun <- tools::toTitleCase(fun)

    margin <- xml2::xml_find_all(bad_expr, margin_xpath)

    # xml_find_first() yields one result per node, so this is NA where na.rm is not supplied
    narm_val <- xml2::xml_text(
      xml2::xml_find_first(bad_expr, "SYMBOL_SUB[text() = 'na.rm']/following-sibling::expr")
    )

    recos <- Map(craft_colsums_rowsums_msg, var, margin, fun, narm_val)

    xml_nodes_to_lints(
      bad_expr,
      source_expression = source_expression,
      lint_message = sprintf("Use %1$s rather than %2$s", recos, get_r_string(bad_expr)),
      type = "warning"
    )
  })
}
109+
110+
# Build the recommended rowSums()/colSums()/rowMeans()/colMeans() replacement
# for a single apply() call.
#
# var:      text of the array argument, spliced verbatim into the recommendation
# margin:   XML node of the MARGIN argument; either a single constant or `l1:l2`
# fun:      text spliced into `row{fun}s` / `col{fun}s` (e.g. "Sum", "Mean")
# narm_val: text of the value given to na.rm, or NA if na.rm was not supplied
#
# Returns a length-1 character string holding the suggested replacement code.
craft_colsums_rowsums_msg <- function(var, margin, fun, narm_val) {
  if (is.na(xml2::xml_find_first(margin, "OP-COLON"))) {
    l1 <- xml2::xml_text(margin)
    l2 <- NULL
  } else {
    l1 <- xml2::xml_text(xml2::xml_find_first(margin, "expr[1]"))
    l2 <- xml2::xml_text(xml2::xml_find_first(margin, "expr[2]"))
  }

  # See #1764 for details about various cases. In short:
  # - If apply(., 1:l2, sum) -> rowSums(., dims = l2)
  # - If apply(., l1:l2, sum) -> rowSums(colSums(., dims = l1 - 1), dims = l2 - l1 + 1)
  # - This last case can be simplified to a simple colSums() call if l2 = length(dim(.))
  # - dims argument can be dropped if equals to 1. This notably is the case for matrices
  if (is.null(l2)) {
    l2 <- l1
  }

  # We don't want warnings when converted as NAs
  l1 <- suppressWarnings(as.integer(re_substitutes(l1, "L$", "")))
  l2 <- suppressWarnings(as.integer(re_substitutes(l2, "L$", "")))

  if (!is.na(narm_val)) {
    narm <- glue::glue(", na.rm = {narm_val}")
  } else {
    narm <- ""
  }

  if (identical(l1, 1L)) {
    reco <- glue::glue("row{fun}s({var}{narm}, dims = {l2})")
  } else {
    reco <- glue::glue(
      "row{fun}s(col{fun}s({var}{narm}, dims = {l1 - 1}), dims = {l2 - l1 + 1})",
      " or ",
      "col{fun}s({var}{narm}, dims = {l1 - 1}) if {var} has {l2} dimensions"
    )
  }

  # It's easier to remove this after the fact, rather than having never ending if/elses.
  # Every generated `dims = N` is immediately followed by ')', so anchoring on it drops
  # only `dims = 1` and leaves e.g. `dims = 10` intact.
  reco <- gsub(", dims = 1)", ")", reco, fixed = TRUE)

  reco
}

inst/lintr/linters.csv

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ is_numeric_linter,readability best_practices consistency
4343
lengths_linter,efficiency readability best_practices
4444
line_length_linter,style readability default configurable
4545
literal_coercion_linter,best_practices consistency efficiency
46+
matrix_apply_linter,readability efficiency
4647
missing_argument_linter,correctness common_mistakes configurable
4748
missing_package_linter,robustness common_mistakes
4849
namespace_linter,correctness robustness configurable executing

man/efficiency_linters.Rd

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/linters.Rd

+3-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/matrix_apply_linter.Rd

+42
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/readability_linters.Rd

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+101
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
test_that("matrix_apply_linter skips allowed usages", {
  linter <- matrix_apply_linter()

  allowed_calls <- c(
    "apply(x, 1, prod)",
    "apply(x, 1, function(i) sum(i[i > 0]))",
    # sum is an extra argument handed to FUN here, not FUN itself
    "apply(x, 1, f, sum)",
    # mean() with named arguments other than na.rm is skipped because they are not
    # implemented in colMeans() or rowMeans()
    "apply(x, 1, mean, trim = 0.2)"
  )

  for (code in allowed_calls) {
    expect_lint(code, NULL, linter)
  }
})
15+
16+
test_that("matrix_apply_linter is not implemented for complex MARGIN values", {
  linter <- matrix_apply_linter()

  unsupported_margins <- c(
    # Could be implemented at some point
    "apply(x, seq(2, 4), sum)",
    # No equivalent
    "apply(x, c(2, 4), sum)",
    # Beyond the scope of static analysis
    "apply(x, m, sum)",
    "apply(x, 1 + 2:4, sum)"
  )

  for (code in unsupported_margins) {
    expect_lint(code, NULL, linter)
  }
})
31+
32+
33+
test_that("matrix_apply_linter simple disallowed usages", {
  linter <- matrix_apply_linter()

  # MARGIN = 1 on sum(): plain rowSums(), however MARGIN/FUN are spelled
  expect_lint("apply(x, 1, sum)", rex::rex("rowSums(x)"), linter)
  expect_lint("apply(x, MARGIN = 1, FUN = sum)", rex::rex("rowSums(x)"), linter)
  expect_lint("apply(x, 1L, sum)", rex::rex("rowSums(x)"), linter)

  # Other MARGIN values on sum(): dims and/or nested colSums() are needed
  expect_lint("apply(x, 1:4, sum)", rex::rex("rowSums(x, dims = 4)"), linter)
  expect_lint("apply(x, 2, sum)", rex::rex("rowSums(colSums(x))"), linter)
  expect_lint("apply(x, 2:4, sum)", rex::rex("rowSums(colSums(x), dims = 3)"), linter)

  # MARGIN = 1 on mean(), including extra args in mean()
  expect_lint("apply(x, 1, mean)", rex::rex("rowMeans"), linter)
  expect_lint("apply(x, MARGIN = 1, FUN = mean)", rex::rex("rowMeans"), linter)
  expect_lint("apply(x, 1, mean, na.rm = TRUE)", rex::rex("rowMeans"), linter)

  # MARGIN starting at 2 on mean()
  expect_lint("apply(x, 2, mean)", rex::rex("colMeans"), linter)
  expect_lint("apply(x, 2:4, mean)", rex::rex("colMeans"), linter)
})
65+
66+
test_that("matrix_apply_linter recommendation includes na.rm if present in original call", {
  linter <- matrix_apply_linter()

  # na.rm = TRUE is carried over for every FUN / MARGIN combination
  calls_with_narm <- c(
    "apply(x, 1, sum, na.rm = TRUE)",
    "apply(x, 2, sum, na.rm = TRUE)",
    "apply(x, 1, mean, na.rm = TRUE)",
    "apply(x, 2, mean, na.rm = TRUE)"
  )
  for (code in calls_with_narm) {
    expect_lint(code, rex::rex("na.rm = TRUE"), linter)
  }

  # No na.rm in the original call
  expect_lint("apply(x, 1, sum)", rex::rex("rowSums(x)"), linter)

  # The na.rm value is not required to be a literal
  expect_lint("apply(x, 1, sum, na.rm = foo)", rex::rex("na.rm = foo"), linter)
})
85+
86+
test_that("matrix_apply_linter works with multiple lints in a single expression", {
  linter <- matrix_apply_linter()

  # Two offending apply() calls nested in a single top-level expression
  code <- "rbind(
      apply(x, 1, sum),
      apply(y, 2:4, mean, na.rm = TRUE)
    )"

  # One lint per apply() call, in source order
  expected_lints <- list(
    rex::rex("Use rowSums(x)"),
    rex::rex("Use rowMeans(colMeans(y, na.rm = TRUE), dims = 3) or colMeans(y, na.rm = TRUE) if y has 4 dimensions")
  )

  expect_lint(code, expected_lints, linter)
})

0 commit comments

Comments
 (0)