Skip to content

Commit

Permalink
Implement across() (#170)
Browse files Browse the repository at this point in the history
Fixes #154
  • Loading branch information
hadley authored Jan 29, 2021
1 parent 6482016 commit c10a2d2
Show file tree
Hide file tree
Showing 15 changed files with 303 additions and 82 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Imports:
crayon,
data.table (>= 1.12.4),
dplyr (>= 1.0.0),
glue,
lifecycle,
rlang,
tibble,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# dtplyr (development version)

* dtplyr can now translate `across()` (#154).

* Objects now printing grouping if present.

* dtplyr now supports `group_map()` and `group_walk()` (#108).

* dtplyr now supports `relocate()` (@smingerson, #162).
Expand Down
1 change: 1 addition & 0 deletions R/step-group.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ add_grouping_param <- function(call, step) {
#' @export
group_by.dtplyr_step <- function(.data, ..., .add = FALSE, add = deprecated(), arrange = TRUE) {
dots <- capture_dots(.data, ...)
dots <- exprs_auto_name(dots)

if (lifecycle::is_present(add)) {
lifecycle::deprecate_warn(
Expand Down
5 changes: 4 additions & 1 deletion R/step.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ compute.dtplyr_step <- function(x, ...) {

#' @rdname collect
#' @export
#' @param keep.rownames Ignored as dplyr never preseres rownames.
#' @param keep.rownames Ignored as dplyr never preserves rownames.
as.data.table.dtplyr_step <- function(x, keep.rownames = FALSE, ...) {
dt_eval(x)[]
}
Expand Down Expand Up @@ -145,6 +145,9 @@ pull.dtplyr_step <- function(.data, var = -1) {
print.dtplyr_step <- function(x, ...) {
cat_line(crayon::bold("Source: "), "local data table ", dplyr::dim_desc(x))
cat_line(crayon::bold("Call: "), expr_text(dt_call(x)))
if (length(x$groups) > 0) {
cat_line(crayon::bold("Groups: "), paste(x$groups, collapse = ", "))
}
cat_line()
cat_line(format(as_tibble(head(x)))[-1]) # Hack to remove "A tibble" line
cat_line()
Expand Down
83 changes: 83 additions & 0 deletions R/tidyeval-across.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
capture_across <- function(vars, x, j = TRUE) {
x <- enquo(x)
dt_squash_across(get_expr(x), get_env(x), vars, j)
}

dt_squash_across <- function(call, env, vars, j = j) {
call <- match.call(dplyr::across, call, expand.dots = FALSE, envir = env)

tbl <- as_tibble(rep_named(vars, list(logical())))
locs <- tidyselect::eval_select(call$.cols, tbl, allow_rename = FALSE)
cols <- syms(vars)[locs]

funs <- across_funs(call$.fns, env, vars, j = j)

# Generate grid of expressions
out <- vector("list", length(cols) * length(funs))
k <- 1
for (i in seq_along(cols)) {
for (j in seq_along(funs)) {
out[[k]] <- exec(funs[[j]], cols[[i]], !!!call$...)
k <- k + 1
}
}

.names <- eval(call$.names, env)
if (!is.null(call$.fns)) {
names(out) <- across_names(vars[locs], names(funs), .names, env)
}
out
}

across_funs <- function(funs, env, vars, j = TRUE) {
if (is.null(funs)) {
list(function(x, ...) x)
} else if (is_symbol(funs)) {
set_names(list(across_fun(funs, env, vars, j = j)), as.character(funs))
} else if (is.character(funs)) {
names(funs)[names2(funs) == ""] <- funs
lapply(funs, across_fun, env, vars, j = j)
} else if (is_call(funs, "~")) {
set_names(list(across_fun(funs, env, vars, j = j)), expr_name(f_rhs(funs)))
} else if (is_call(funs, "list")) {
args <- rlang::exprs_auto_name(funs[-1])
lapply(args, across_fun, env, vars, j = j)
} else if (!is.null(env)) {
# Try evaluating once, just in case
funs <- eval(funs, env)
across_funs(funs, NULL)
} else {
abort("`.fns` argument to dtplyr::across() must be a NULL, a function name, formula, or list")
}
}

across_fun <- function(fun, env, vars, j = TRUE) {
if (is_symbol(fun) || is_string(fun)) {
function(x, ...) call2(fun, x, ...)
} else if (is_call(fun, "~")) {
call <- f_rhs(fun)
call <- replace_dot(call, quote(!!.x))
call <- dt_squash_call(call, env, vars, j = TRUE)

function(x, ...) expr_interp(call, child_env(emptyenv(), .x = x))
} else {
abort(c(
".fns argument to dtplyr::across() contain a function name or a formula",
x = paste0("Problem with ", expr_deparse(fun))
))
}
}

across_names <- function(cols, funs, names = NULL, env = parent.frame()) {
if (length(funs) == 1) {
names <- names %||% "{.col}"
} else {
names <- names %||% "{.col}_{.fn}"
}

glue_env <- child_env(env,
.col = rep(cols, each = length(funs)),
.fn = rep(funs %||% seq_along(funs), length(cols))
)
glue::glue(names, .envir = glue_env)
}
101 changes: 56 additions & 45 deletions R/tidyeval.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ add_dt_wrappers <- function(env) {
capture_dots <- function(.data, ..., .j = TRUE) {
dots <- enquos(..., .named = .j)
dots <- lapply(dots, dt_squash, vars = .data$vars, j = .j)
dots

# Remove names from any list elements
is_list <- vapply(dots, is.list, logical(1))
names(dots)[is_list] <- ""

unlist(dots, recursive = FALSE)
}

capture_dot <- function(.data, x, j = TRUE) {
Expand Down Expand Up @@ -67,49 +72,55 @@ dt_squash <- function(x, env, vars, j = TRUE) {
}
} else if (is_quosure(x)) {
dt_squash(get_expr(x), get_env(x), vars = vars, j = j)
} else if (is_call(x, "across")) {
dt_squash_across(x, env, vars, j = j)
} else if (is_call(x)) {
if (is_mask_pronoun(x)) {
var <- x[[3]]
if (is_call(x, "[[")) {
var <- sym(eval(var, env))
}
dt_squash_call(x, env, vars, j = j)
} else {
abort("Invalid input")
}
}

if (is_symbol(x[[2]], ".data")) {
var
} else if (is_symbol(x[[2]], ".env")) {
sym(paste0("..", var))
}
} else if (is_call(x, "n", n = 0)) {
quote(.N)
} else if (is_call(x, "row_number", n = 0)) {
quote(seq_len(.N))
} else if (is_call(x, "row_number", n = 1)) {
arg <- dt_squash(x[[2]], vars = vars, env = env, j = j)
expr(frank(!!arg, ties.method = "first", na.last = "keep"))
} else if (is_call(x, "if_else")) {
x[[1L]] <- quote(fifelse)
x
} else if (is_call(x, 'coalesce')) {
x[[1L]] <- quote(fcoalesce)
x
} else if (is_call(x, "cur_data")) {
quote(.SD)
} else if (is_call(x, "cur_data_all")) {
abort("`cur_data_all()` is not available in dtplyr")
} else if (is_call(x, "cur_group")) {
quote(.BY)
} else if (is_call(x, "cur_group_id")) {
quote(.GRP)
} else if (is_call(x, "cur_group_rows")) {
quote(.I)
} else if (is.function(x[[1]]) || is_call(x, "function")) {
simplify_function_call(x, env, vars = vars, j = j)
} else {
x[-1] <- lapply(x[-1], dt_squash, vars = vars, env = env, j = j)
x
dt_squash_call <- function(x, env, vars, j = TRUE) {
if (is_mask_pronoun(x)) {
var <- x[[3]]
if (is_call(x, "[[")) {
var <- sym(eval(var, env))
}

if (is_symbol(x[[2]], ".data")) {
var
} else if (is_symbol(x[[2]], ".env")) {
sym(paste0("..", var))
}
} else if (is_call(x, "n", n = 0)) {
quote(.N)
} else if (is_call(x, "row_number", n = 0)) {
quote(seq_len(.N))
} else if (is_call(x, "row_number", n = 1)) {
arg <- dt_squash(x[[2]], vars = vars, env = env, j = j)
expr(frank(!!arg, ties.method = "first", na.last = "keep"))
} else if (is_call(x, "if_else")) {
x[[1L]] <- quote(fifelse)
x
} else if (is_call(x, 'coalesce')) {
x[[1L]] <- quote(fcoalesce)
x
} else if (is_call(x, "cur_data")) {
quote(.SD)
} else if (is_call(x, "cur_data_all")) {
abort("`cur_data_all()` is not available in dtplyr")
} else if (is_call(x, "cur_group")) {
quote(.BY)
} else if (is_call(x, "cur_group_id")) {
quote(.GRP)
} else if (is_call(x, "cur_group_rows")) {
quote(.I)
} else if (is.function(x[[1]]) || is_call(x, "function")) {
simplify_function_call(x, env, vars = vars, j = j)
} else {
abort("Invalid input")
x[-1] <- lapply(x[-1], dt_squash, vars = vars, env = env, j = j)
x
}
}

Expand All @@ -133,7 +144,7 @@ is_global <- function(env) {
simplify_function_call <- function(x, env, vars, j = TRUE) {
if (inherits(x[[1]], "inline_colwise_function")) {
dot_var <- vars[[attr(x, "position")]]
out <- replace_dot(attr(x[[1]], "formula")[[2]], dot_var)
out <- replace_dot(attr(x[[1]], "formula")[[2]], sym(dot_var))
dt_squash(out, env, vars = vars, j = j)
} else {
name <- fun_name(x[[1]])
Expand All @@ -147,11 +158,11 @@ simplify_function_call <- function(x, env, vars, j = TRUE) {
}
}

replace_dot <- function(call, var) {
if (is_symbol(call, ".")) {
sym(var)
replace_dot <- function(call, sym) {
if (is_symbol(call, ".") || is_symbol(call, ".x")) {
sym
} else if (is_call(call)) {
call[] <- lapply(call, replace_dot, var = var)
call[] <- lapply(call, replace_dot, sym)
call
} else {
call
Expand Down
2 changes: 1 addition & 1 deletion man/collect.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions tests/testthat/_snaps/step.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# has useful display methods

Code
dt <- lazy_dt(mtcars, "DT")
Code
dt
Output
Source: local data table [32 x 11]
Call: DT
mpg cyl disp hp drat wt qsec vs am gear carb
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 21 6 160 110 3.9 2.62 16.5 0 1 4 4
2 21 6 160 110 3.9 2.88 17.0 0 1 4 4
3 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
4 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
5 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
# Use as.data.table()/as.data.frame()/as_tibble() to access results
Code
dt %>% group_by(vs, am)
Output
Source: local data table [?? x 11]
Call: DT
Groups: vs, am
mpg cyl disp hp drat wt qsec vs am gear carb
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 21 6 160 110 3.9 2.62 16.5 0 1 4 4
2 21 6 160 110 3.9 2.88 17.0 0 1 4 4
3 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
4 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
5 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
# Use as.data.table()/as.data.frame()/as_tibble() to access results

12 changes: 12 additions & 0 deletions tests/testthat/_snaps/tidyeval-across.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# across() gives informative errors

Code
capture_across(letters, across(a, 1))
Error <rlang_error>
`.fns` argument to dtplyr::across() must be a NULL, a function name, formula, or list
Code
capture_across(letters, across(a, list(1)))
Error <rlang_error>
.fns argument to dtplyr::across() contain a function name or a formula
x Problem with 1

6 changes: 6 additions & 0 deletions tests/testthat/test-step-group.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@ test_that("grouping and ungrouping adjust groups field", {

expect_equal(dt %>% .$groups, character())
expect_equal(dt %>% group_by(x) %>% .$groups, "x")
expect_equal(dt %>% group_by(a = x) %>% .$groups, "a")
expect_equal(dt %>% group_by(x) %>% group_by(y) %>% .$groups, "y")
expect_equal(dt %>% group_by(x) %>% ungroup() %>% .$groups, character())
})

test_that("can use across", {
dt <- lazy_dt(data.frame(x = 1:3, y = 1:3))
expect_equal(dt %>% group_by(across(everything())) %>% .$groups, c("x", "y"))
})

test_that("can add groups if requested", {
dt <- lazy_dt(data.frame(x = 1:3, y = 1:3), "DT")
expect_equal(
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-step-mutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ test_that("mutate generates compound expression if needed", {
)
})

test_that("can use across", {
dt <- lazy_dt(data.table(x = 1, y = 2), "DT")

expect_equal(
dt %>% mutate(across(everything(), ~ . + 1)) %>% show_query(),
expr(copy(DT)[, `:=`(x = x + 1, y = y + 1)])
)
})

test_that("vars set correctly", {
dt <- lazy_dt(data.frame(x = 1:3, y = 1:3))
expect_equal(dt %>% mutate(z = 1) %>% .$vars, c("x", "y", "z"))
Expand Down
30 changes: 0 additions & 30 deletions tests/testthat/test-step-print.txt

This file was deleted.

Loading

0 comments on commit c10a2d2

Please sign in to comment.