From fe9cc96f7765862ed5db839591abb628ed067229 Mon Sep 17 00:00:00 2001 From: markfairbanks Date: Mon, 14 Nov 2022 10:53:23 -0700 Subject: [PATCH] Simplify and allow top level `pick()` --- R/tidyeval-across.R | 4 ++-- R/tidyeval.R | 34 ++++++++++++++-------------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/R/tidyeval-across.R b/R/tidyeval-across.R index ae2cf496..9958055b 100644 --- a/R/tidyeval-across.R +++ b/R/tidyeval-across.R @@ -3,10 +3,10 @@ capture_across <- function(data, x, j = TRUE) { dt_squash_across(get_expr(x), get_env(x), data, j) } -dt_squash_across <- function(call, env, data, j = j, is_top_across = TRUE) { +dt_squash_across <- function(call, env, data, j = j, is_top = TRUE) { call <- match.call(dplyr::across, call, expand.dots = FALSE, envir = env) out <- across_setup(data, call, env, allow_rename = TRUE, j = j, fn = "across()") - if (is_false(is_top_across)) { + if (is_false(is_top)) { out <- call2("data.table", !!!out) } out diff --git a/R/tidyeval.R b/R/tidyeval.R index 18b2db6d..8fb598bf 100644 --- a/R/tidyeval.R +++ b/R/tidyeval.R @@ -30,8 +30,7 @@ globalVariables(dt_funs) capture_dots <- function(.data, ..., .j = TRUE) { dots <- enquos(..., .named = .j) - top_across <- map(dots, quo_is_call, "across") - dots <- map2(dots, top_across, ~ dt_squash(.x, data = .data, j = .j, is_top_across = .y)) + dots <- map(dots, dt_squash, data = .data, j = .j) # Remove names from any list elements is_list <- map_lgl(dots, is.list) @@ -46,7 +45,7 @@ capture_new_vars <- function(.data, ...) { dots <- as.list(enquos(..., .named = TRUE)) for (i in seq_along(dots)) { dot <- dots[[i]] - dot <- dt_squash(dot, data = .data, is_top_across = quo_is_call(dot, "across")) + dot <- dt_squash(dot, data = .data) if (is.null(dot)) { dots[i] <- list(NULL) } else { @@ -69,7 +68,7 @@ capture_dot <- function(.data, x, j = TRUE) { } # squash quosures -dt_squash <- function(x, env, data, j = TRUE, is_top_across = TRUE) { +dt_squash <- function(x, env, data, j = TRUE, is_top = TRUE) { if (is_atomic(x) || is_null(x)) { x } else if (is_symbol(x)) { @@ -103,30 +102,25 @@ dt_squash <- function(x, env, data, j = TRUE, is_top_across = TRUE) { } } } else if (is_quosure(x)) { - dt_squash(get_expr(x), get_env(x), data, j = j, is_top_across) + dt_squash(get_expr(x), get_env(x), data, j = j, is_top) } else if (is_call(x, "if_any")) { dt_squash_if(x, env, data, j = j, reduce = "|") } else if (is_call(x, "if_all")) { dt_squash_if(x, env, data, j = j, reduce = "&") } else if (is_call(x, "across")) { - dt_squash_across(x, env, data, j = j, is_top_across) + dt_squash_across(x, env, data, j = j, is_top) } else if (is_call(x, "pick")) { - call <- call_match(x, pick, dots_expand = FALSE) - .cols <- call2("c", !!!call$...) - across_call <- call2("across", .cols) - dt_squash_across(across_call, env, data, j, is_top_across) + x[[1]] <- sym("c") + call <- call2("across", x) + dt_squash_across(call, env, data, j, is_top) } else if (is_call(x)) { - dt_squash_call(x, env, data, j = j, is_top_across) + dt_squash_call(x, env, data, j = j, is_top) } else { abort("Invalid input") } } -pick <- function(...) { - "yep" -} - -dt_squash_call <- function(x, env, data, j = TRUE, is_top_across = TRUE) { +dt_squash_call <- function(x, env, data, j = TRUE, is_top = TRUE) { if (is_mask_pronoun(x)) { var <- x[[3]] if (is_call(x, "[[")) { @@ -138,7 +132,7 @@ dt_squash_call <- function(x, env, data, j = TRUE, is_top_across = TRUE) { sym(paste0("..", var)) } } else if (is_call(x, c("coalesce", "replace_na"))) { - args <- lapply(x[-1], dt_squash, env, data, j, is_top_across) + args <- lapply(x[-1], dt_squash, env, data, j, is_top) call2("fcoalesce", !!!args) } else if (is_call(x, "case_when")) { # case_when(x ~ y) -> fcase(x, y) @@ -150,7 +144,7 @@ dt_squash_call <- function(x, env, data, j = TRUE, is_top_across = TRUE) { x[[3]] ) })) - args <- lapply(args, dt_squash, env = env, data = data, j = j, is_top_across) + args <- lapply(args, dt_squash, env = env, data = data, j = j, is_top) call2("fcase", !!!args) } else if (is_call(x, "cur_data")) { quote(.SD) @@ -175,7 +169,7 @@ dt_squash_call <- function(x, env, data, j = TRUE, is_top_across = TRUE) { } x[[1]] <- quote(fifelse) - x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top_across) + x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top) x } else if (is_call(x, c("lag", "lead"))) { if (is_call(x, "lag")) { @@ -248,7 +242,7 @@ dt_squash_call <- function(x, env, data, j = TRUE, is_top_across = TRUE) { } call } else { - x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top_across) + x[-1] <- lapply(x[-1], dt_squash, env, data, j = j, is_top = FALSE) x } }