Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Make apply_schema work within the arguments to a FunctionalTerm #117

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions docs/src/api.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
```@meta
CurrentModule = StatsModels
DocTestSetup = quote
DocTestSetup = quote
using StatsModels, Random, StatsBase
Random.seed!(2001)
end
Expand All @@ -23,13 +23,13 @@ modelcols
```@docs
FormulaTerm
InteractionTerm
FunctionTerm
```

### Placeholder terms

```@docs
Term
CallTerm
ConstantTerm
```

Expand All @@ -40,6 +40,7 @@ These are all generated by [`apply_schema`](@ref).
```@docs
ContinuousTerm
CategoricalTerm
FunctionCallTerm
InterceptTerm
MatrixTerm
collect_matrix_terms
Expand Down Expand Up @@ -76,8 +77,8 @@ StatsModels.drop_intercept
These are internal implementation details that are likely to change in the
near future. In particular, the `ModelFrame` and `ModelMatrix` wrappers are
dispreferred in favor of using terms directly, and can in most cases be
replaced by something like
replaced by something like

```julia
# instead of ModelMatrix(ModelFrame(f::FormulaTerm, data, model=MyModel))
sch = schema(f, data)
Expand Down
10 changes: 7 additions & 3 deletions src/StatsModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export
#re-export from StatsBase:
StatisticalModel,
RegressionModel,

@formula,
ModelFrame,
ModelMatrix,
Expand All @@ -38,7 +38,8 @@ export
InteractionTerm,
FormulaTerm,
InterceptTerm,
FunctionTerm,
CallTerm,
FunctionCallTerm,
MatrixTerm,

term,
Expand All @@ -50,13 +51,16 @@ export
width,
modelcols,
modelmatrix,
response
response,
protect,
unprotect

include("traits.jl")
include("contrasts.jl")
include("terms.jl")
include("schema.jl")
include("formula.jl")
include("protection.jl")
include("modelframe.jl")
include("statsmodel.jl")

Expand Down
8 changes: 3 additions & 5 deletions src/formula.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ end
"""
And1 <: FormulaRewrite

Remove numbers from interaction terms, so `1&x` becomes `&(x)` (which is later
Remove numbers from interaction terms, so `1&x` becomes `&(x)` (which is later
cleaned up by `EmptyAnd`).
"""
struct And1 <: FormulaRewrite end
Expand Down Expand Up @@ -188,7 +188,7 @@ function parse!(ex::Expr, rewrites::Vector)

# parse a copy of non-special calls
ex_parsed = ex.args[1] ∉ SPECIALS ? deepcopy(ex) : ex

# iterate over children, checking for special rules
child_idx = 2
while child_idx <= length(ex_parsed.args)
Expand All @@ -215,17 +215,15 @@ end

Capture a call to a function that is not part of the formula DSL. This replaces
`ex` with a call to [`capture_call`](@ref). `ex_parsed` is a copy of `ex` whose
arguments have been parsed according to the normal formula DSL rules and which
arguments have been parsed according to the normal formula DSL rules and which
will be passed as the final argument to `capture_call`.
"""
function capture_call_ex!(ex::Expr, ex_parsed::Expr)
symbols = extract_symbols(ex)
symbols_ex = Expr(:tuple, symbols...)
f_anon_ex = esc(Expr(:(->), symbols_ex, copy(ex)))
f_orig = ex.args[1]
ex.args = [:capture_call,
esc(f_orig),
f_anon_ex,
tuple(symbols...),
Meta.quot(deepcopy(ex)),
:[$(ex_parsed.args[2:end]...)]]
Expand Down
56 changes: 56 additions & 0 deletions src/protection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
abstract type PolyModel end

protect(x) = error("protect should only be used within a @formula")

"""
ProtectedCtx{OldCtx}
is a context type that is entered during the applictation of a schema to a
`ProtectedTerm`. It holds the `OldCtx`
"""
struct ProtectedCtx{OldCtx} end
function StatsModels.apply_schema(t::CallTerm{typeof(protect)}, sch, Mod::Type)
length(t.args_parsed) == 1 || throw(ArgumentError("`protect` only applies to a single term."))
parsed_term = t.args_parsed[1]
return apply_schema(parsed_term, sch, ProtectedCtx{Mod})
end


# Outside of a @formula unprotect strips the protect wrapper
unprotect(t::CallTerm{typeof(protect)}) = t.args_parsed[1]
unprotect(t) = t
function StatsModels.apply_schema(t::CallTerm{typeof(unprotect)}, sch, Mod::Type)
throw(DomainError("`unprotect` used outside a protected context."))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using a DomainError seems a little punny.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but for once ArgumentError seems wrong.
Could be a FormulaSyntaxError maybe?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or even ParseError or LoadError or whatever the error is that's thrown for invalid syntax? (but that's maybe too punny as well)

end
function StatsModels.apply_schema(t::CallTerm{typeof(unprotect)}, sch, Mod::Type{<:ProtectedCtx{OldCtx}}) where OldCtx
length(t.args_parsed) == 1 || throw(ArgumentError("`unprotect` only applies to a single term."))
parsed_term = t.args_parsed[1]
return apply_schema(parsed_term, sch, OldCtx)
end

## Defintion of how things act while protected:

# TODO: Transform * into FunctionTerms
# https://github.com/JuliaStats/StatsModels.jl/issues/119

apply_schema(t::ConstantTerm, schema, Mod::Type{<:ProtectedCtx}) = t

function direct_call(op, arg_terms::Tuple)
names = Tuple(termvars(arg_terms))
ex = Expr(:call, nameof(op), names...)
ct = CallTerm{typeof(op), names}(+, ex, t)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be op?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

return call_fallback_apply_schema(ct, schema, Mod)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here as below...why not just call apply_schema(ct, schema, Mod)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, probably not needed.

end
function apply_schema(t::TupleTerm, schema, Mod::Type{<:ProtectedCtx})
# TupleTerm is what is created by `x+y`, we need to turn that back into addition.
return direct_call(+, t)
end

function apply_schema(t::InteractionTerm, schema, Mod::Type{<:ProtectedCtx})
# InteractionTerm is what is created by `x&y`, we need to turn that back into bitwise and.
return direct_call(&, t.terms)
end


# Lets not do the below by default. Instead overloaded call terms should opt into the fallback for during ProtectedCtx
# that way we avoid and ambiguity error.
# apply_schema(ct::CallTerm, schema, Mod::Type{<:ProtectedCtx}) = call_fallback_apply_schema(ct, schema, Mod)
45 changes: 36 additions & 9 deletions src/schema.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@

terms(t::FormulaTerm) = union(terms(t.lhs), terms(t.rhs))
terms(t::InteractionTerm) = terms(t.terms)
terms(t::FunctionTerm{Fo,Fa,names}) where {Fo,Fa,names} = Term.(names)
terms(t::CallTerm) = Term.(termnames(t)) # TODO: This is wrong because termnames is wrong
terms(t::FunctionCallTerm) = t.terms
terms(t::AbstractTerm) = [t]
terms(t::MatrixTerm) = terms(t.terms)
terms(t::TupleTerm) = mapreduce(terms, union, t)


needs_schema(::AbstractTerm) = true
needs_schema(::ConstantTerm) = false
needs_schema(t) = false
Expand All @@ -26,7 +28,7 @@ needs_schema(t) = false
Compute all the invariants necessary to fit a model with `terms`. A schema is a dict that
maps `Term`s to their concrete instantiations (either `CategoricalTerm`s or
`ContinuousTerm`s. "Hints" may optionally be supplied in the form of a `Dict` mapping term
names (as `Symbol`s) to term or contrast types. If a hint is not provided for a variable,
names (as `Symbol`s) to term or contrast types. If a hint is not provided for a variable,
the appropriate term type will be guessed based on the data type from the data column: any
numeric data is assumed to be continuous, and any non-numeric data is assumed to be
categorical.
Expand Down Expand Up @@ -56,7 +58,7 @@ Dict{Any,Any} with 1 entry:
y => y
```

Note that concrete `ContinuousTerm` and `CategoricalTerm` and un-typed `Term`s print the
Note that concrete `ContinuousTerm` and `CategoricalTerm` and un-typed `Term`s print the
same in a container, but when printed alone are different:

```jldoctest 1
Expand Down Expand Up @@ -159,9 +161,9 @@ end
Return a new term that is the result of applying `schema` to term `t` with
destination model (type) `Mod`. If `Mod` is omitted, `Nothing` will be used.

When `t` is a `ContinuousTerm` or `CategoricalTerm` already, the term will be returned
unchanged _unless_ a matching term is found in the schema. This allows
selective re-setting of a schema to change the contrast coding or levels of a
When `t` is a `ContinuousTerm` or `CategoricalTerm` already, the term will be returned
unchanged _unless_ a matching term is found in the schema. This allows
selective re-setting of a schema to change the contrast coding or levels of a
categorical term, or to change a continuous term to categorical or vice versa.
"""
apply_schema(t, schema) = apply_schema(t, schema, Nothing)
Expand All @@ -180,6 +182,27 @@ apply_schema(t::Union{ContinuousTerm, CategoricalTerm}, schema, Mod::Type) =
get(schema, term(t.sym), t)
apply_schema(t::MatrixTerm, sch, Mod::Type) = MatrixTerm(apply_schema.(t.terms, Ref(sch), Mod))

function call_fallback_apply_schema(ct::CallTerm{F, Names}, schema, Mod) where {F, Names}
# First we apply schema to all terms inside the CallTerm arguments.
# Thus allowing them to have overloaded `apply_schema` behavour
terms = map(ct.args_parsed) do arg
apply_schema(arg, schema, Mod)
end
names = Symbol[Names...]
ft = FunctionCallTerm(ct.forig, names, terms, ct.exorig)

# Last, we apply the schema to the FunctionCallTerm, so it can have overloaded
# apply_schema behavour -- but the fallback it to leave it as is
# which will result in a FunctionCallTerm in the final formula
# so the function will be called in `modelcols`
return apply_schema(ft, schema, Mod)
end
apply_schema(ct::CallTerm, schema, Mod::Type) = call_fallback_apply_schema(ct, schema, Mod)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this indirection necessary? why not just put the body of call_fallback_apply_schema in this method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reasons of making it easy to write things that avoid ambiguities,
when defining protected methods.

Though I am not sure it is used in any code I have committed,
I may have backed the other stuff that used it out,
or not written it yet.


# To get back (approx) old behavour of FunctionTerm do
# apply_schema(ct::CallTerm, schema, Mod::Type) = call_fallback_apply_schema(ct, schema, ProtectedCtx{Mod})



# TODO: special case this for <:RegressionModel ?
function apply_schema(t::ConstantTerm, schema, Mod::Type)
Expand Down Expand Up @@ -234,7 +257,7 @@ function apply_schema(t::FormulaTerm, schema, Mod::Type{<:StatisticalModel})
end

# strategy is: apply schema, then "repair" if necessary (promote to full rank
# contrasts).
# contrasts).
#
# to know whether to repair, need to know context a term appears in. main
# effects occur in "own" context.
Expand Down Expand Up @@ -309,7 +332,7 @@ termsyms(t::InterceptTerm{true}) = Set(1)
termsyms(t::ConstantTerm) = Set((t.n,))
termsyms(t::Union{Term, CategoricalTerm, ContinuousTerm}) = Set([t.sym])
termsyms(t::InteractionTerm) = mapreduce(termsyms, union, t.terms)
termsyms(t::FunctionTerm) = Set([t.exorig])
termsyms(t::Union{CallTerm,FunctionCallTerm}) = Set(termnames(t))

symequal(t1::AbstractTerm, t2::AbstractTerm) = issetequal(termsyms(t1), termsyms(t2))

Expand All @@ -325,4 +348,8 @@ termvars(t::InteractionTerm) = mapreduce(termvars, union, t.terms)
termvars(t::TupleTerm) = mapreduce(termvars, union, t, init=Symbol[])
termvars(t::MatrixTerm) = termvars(t.terms)
termvars(t::FormulaTerm) = union(termvars(t.lhs), termvars(t.rhs))
termvars(t::FunctionTerm{Fo,Fa,names}) where {Fo,Fa,names} = collect(names)
termvars(t::Union{CallTerm,FunctionCallTerm}) = collect(termnames(t))


termnames(::CallTerm{<:Any, Names}) where Names = Names
termnames(ft::FunctionCallTerm) = ft.names
Loading