Skip to content

Commit

Permalink
Merge pull request #397 from tidyverse/implement-pick
Browse files Browse the repository at this point in the history
Implement `pick()`
  • Loading branch information
markfairbanks authored Nov 14, 2022
2 parents 5d71f4c + da7029c commit 5750999
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 14 deletions.
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

* `min_rank()`, `dense_rank()`, `percent_rank()`, & `cume_dist()` are now translated
to their `data.table` equivalents (#396)

* `pick()` is now translated (#341)

* `across()` output can now be used as a data frame (#341)

* `names_glue` now works in `pivot_wider()` when `names_from` contains `NA`s (#394)

Expand Down
8 changes: 6 additions & 2 deletions R/tidyeval-across.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ capture_across <- function(data, x, j = TRUE) {
dt_squash_across(get_expr(x), get_env(x), data, j)
}

dt_squash_across <- function(call, env, data, j = j) {
dt_squash_across <- function(call, env, data, j = j, is_top = TRUE) {
call <- match.call(dplyr::across, call, expand.dots = FALSE, envir = env)
across_setup(data, call, env, allow_rename = TRUE, j = j, fn = "across()")
out <- across_setup(data, call, env, allow_rename = TRUE, j = j, fn = "across()")
if (is_false(is_top)) {
out <- call2("data.table", !!!out)
}
out
}

capture_if_all <- function(data, x, j = TRUE) {
Expand Down
27 changes: 16 additions & 11 deletions R/tidyeval.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ globalVariables(dt_funs)

capture_dots <- function(.data, ..., .j = TRUE) {
dots <- enquos(..., .named = .j)
dots <- lapply(dots, dt_squash, data = .data, j = .j)
dots <- map(dots, dt_squash, data = .data, j = .j)

# Remove names from any list elements
is_list <- map_lgl(dots, is.list)
Expand All @@ -44,7 +44,8 @@ capture_dots <- function(.data, ..., .j = TRUE) {
capture_new_vars <- function(.data, ...) {
dots <- as.list(enquos(..., .named = TRUE))
for (i in seq_along(dots)) {
dot <- dt_squash(dots[[i]], data = .data)
dot <- dots[[i]]
dot <- dt_squash(dot, data = .data)
if (is.null(dot)) {
dots[i] <- list(NULL)
} else {
Expand All @@ -67,7 +68,7 @@ capture_dot <- function(.data, x, j = TRUE) {
}

# squash quosures
dt_squash <- function(x, env, data, j = TRUE) {
dt_squash <- function(x, env, data, j = TRUE, is_top = TRUE) {
if (is_atomic(x) || is_null(x)) {
x
} else if (is_symbol(x)) {
Expand Down Expand Up @@ -101,21 +102,25 @@ dt_squash <- function(x, env, data, j = TRUE) {
}
}
} else if (is_quosure(x)) {
dt_squash(get_expr(x), get_env(x), data, j = j)
dt_squash(get_expr(x), get_env(x), data, j = j, is_top)
} else if (is_call(x, "if_any")) {
dt_squash_if(x, env, data, j = j, reduce = "|")
} else if (is_call(x, "if_all")) {
dt_squash_if(x, env, data, j = j, reduce = "&")
} else if (is_call(x, "across")) {
dt_squash_across(x, env, data, j = j)
dt_squash_across(x, env, data, j = j, is_top)
} else if (is_call(x, "pick")) {
x[[1]] <- sym("c")
call <- call2("across", x)
dt_squash_across(call, env, data, j, is_top)
} else if (is_call(x)) {
dt_squash_call(x, env, data, j = j)
dt_squash_call(x, env, data, j = j, is_top)
} else {
abort("Invalid input")
}
}

dt_squash_call <- function(x, env, data, j = TRUE) {
dt_squash_call <- function(x, env, data, j = TRUE, is_top = TRUE) {
if (is_mask_pronoun(x)) {
var <- x[[3]]
if (is_call(x, "[[")) {
Expand All @@ -127,7 +132,7 @@ dt_squash_call <- function(x, env, data, j = TRUE) {
sym(paste0("..", var))
}
} else if (is_call(x, c("coalesce", "replace_na"))) {
args <- lapply(x[-1], dt_squash, env = env, data = data, j = j)
args <- lapply(x[-1], dt_squash, env, data, j, is_top)
call2("fcoalesce", !!!args)
} else if (is_call(x, "case_when")) {
# case_when(x ~ y) -> fcase(x, y)
Expand All @@ -139,7 +144,7 @@ dt_squash_call <- function(x, env, data, j = TRUE) {
x[[3]]
)
}))
args <- lapply(args, dt_squash, env = env, data = data, j = j)
args <- lapply(args, dt_squash, env = env, data = data, j = j, is_top)
call2("fcase", !!!args)
} else if (is_call(x, "cur_data")) {
quote(.SD)
Expand All @@ -164,7 +169,7 @@ dt_squash_call <- function(x, env, data, j = TRUE) {
}

x[[1]] <- quote(fifelse)
x[-1] <- lapply(x[-1], dt_squash, env, data, j = j)
x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top)
x
} else if (is_call(x, c("lag", "lead"))) {
if (is_call(x, "lag")) {
Expand Down Expand Up @@ -237,7 +242,7 @@ dt_squash_call <- function(x, env, data, j = TRUE) {
}
call
} else {
x[-1] <- lapply(x[-1], dt_squash, env, data, j = j)
x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top = FALSE)
x
}
}
Expand Down
31 changes: 30 additions & 1 deletion tests/testthat/test-tidyeval-across.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,35 @@ test_that("across() .cols is evaluated in across()'s calling environment", {
)
})

test_that("across() output can be used as a data frame", {
  # Lazy table with two integer columns and one character column.
  lazy <- lazy_dt(tibble(x = 1:3, y = 1:3, z = c("a", "a", "b")))

  # End to end: across() is nested inside rowSums(), so its result has to be
  # usable as a data frame argument.
  mutated <- mutate(lazy, across_df = rowSums(across(c(x, y), ~ .x + 1)))
  out <- collect(mutated)

  expect_named(out, c("x", "y", "z", "across_df"))
  # (x + 1) + (y + 1) for each row.
  expect_equal(out$across_df, c(4, 6, 8))

  # Translation only: a non-top-level across() should be squashed into a
  # data.table() call.
  squashed <- dt_squash(
    expr(across(c(x, y), ~ .x + 1)),
    env = lazy$env,
    data = lazy,
    is_top = FALSE
  )
  expect_equal(squashed, expr(data.table(x = x + 1, y = y + 1)))
})

test_that("pick() works", {
  # Lazy table with two integer columns and one character column.
  lazy <- lazy_dt(tibble(x = 1:3, y = 1:3, z = c("a", "a", "b")))

  # End to end: pick() is nested inside rowSums(), so its result has to be
  # usable as a data frame argument.
  mutated <- mutate(lazy, row_sum = rowSums(pick(x, y)))
  out <- collect(mutated)

  expect_named(out, c("x", "y", "z", "row_sum"))
  # x + y for each row.
  expect_equal(out$row_sum, c(2, 4, 6))

  # Translation only: a non-top-level pick() should be squashed into a
  # data.table() call over the selected columns.
  squashed <- dt_squash(
    expr(pick(x, y)),
    env = lazy$env,
    data = lazy,
    is_top = FALSE
  )
  expect_equal(squashed, expr(data.table(x = x, y = y)))

  # Top level pick works
  grouped <- group_by(lazy, pick(x, y))
  expect_equal(grouped$groups, c("x", "y"))
})

# if_all ------------------------------------------------------------------

test_that("if_all collapses multiple expresions", {
Expand Down Expand Up @@ -268,7 +297,7 @@ test_that("if_all() can handle empty selection", {
)
})

test_that("across() .cols is evaluated in across()'s calling environment", {
test_that("if_all() .cols is evaluated in across()'s calling environment", {
dt <- lazy_dt(data.frame(y = 1))
fun <- function(x) capture_if_all(dt, if_all(all_of(x)))
expect_equal(
Expand Down

0 comments on commit 5750999

Please sign in to comment.