-
-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathPredictionData.R
63 lines (57 loc) · 2.22 KB
/
PredictionData.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#' @title Convert to PredictionData
#'
#' @name PredictionData
#' @rdname PredictionData
#'
#' @description
#' Objects of type `PredictionData` serve as an intermediate representation for objects of type [Prediction].
#' It is an internal data structure, implemented to optimize runtime and to avoid some issues that arise when serializing R6 objects.
#' End-users typically do not need to worry about the details; package developers are advised to continue reading for some technical information.
#'
#' Unlike most other \CRANpkg{mlr3} objects, `PredictionData` relies on the S3 class system.
#' The following operations must be supported to extend mlr3 for new task types:
#'
#' * [as_prediction_data()] converts objects to class `PredictionData`, e.g. objects of type [Prediction].
#' * [as_prediction()] converts objects to class [Prediction], e.g. objects of type `PredictionData`.
#' * `check_prediction_data()` is called on the return value of the predict method of a [Learner] to perform assertions and type conversions.
#' Returns an updated object of class `PredictionData`.
#' * `is_missing_prediction_data()` is used for the fallback learner (see [Learner]) to impute missing predictions.
#' Returns a vector of row ids that need imputation.
#'
#'
NULL
# Internal constructor for `PredictionData` objects (not exported).
# Removes all NULL elements from the list `li`, then sets the S3 class to
# c(<task-type specific class>, "PredictionData"). The task-type specific class
# is read from `mlr_reflections$task_types` via
# fget(..., task_type, "prediction_data", "type").
new_prediction_data = function(li, task_type) {
  pdata = discard(li, is.null)
  type_class = fget(mlr_reflections$task_types, task_type, "prediction_data", "type")
  class(pdata) = c(type_class, "PredictionData")
  pdata
}
#' @rdname PredictionData
#'
#' @param task ([Task]).
#' @param learner ([Learner]).
#'
#' @export
create_empty_prediction_data = function(task, learner) {
  # S3 generic: UseMethod() without an explicit object dispatches on the class
  # of the first argument, `task`. No methods are defined in this section;
  # presumably each task type supplies its own method (confirm elsewhere).
  UseMethod("create_empty_prediction_data")
}
#' @export
print.PredictionData = function(x, ...) {
  # Compact one-line summary "<%s:%i>": the most specific class of `x`
  # (first element of class(x)) and the number of stored row ids.
  catf("<%s:%i>", class(x)[1L], length(x$row_ids))
  # R convention: print methods return their argument invisibly, so that
  # `y = print(pdata)` keeps the object and autoprint does not print it twice.
  invisible(x)
}
#' @rdname PredictionData
#' @param pdata ([PredictionData])\cr
#' Named list inheriting from `"PredictionData"`.
#' @export
check_prediction_data = function(pdata, ...) {
  # S3 generic: dispatches on class(pdata). Per the PredictionData description,
  # methods perform assertions and type conversions on a learner's predict output
  # and return an updated `PredictionData` object. Extra arguments in `...` are
  # passed to the method.
  UseMethod("check_prediction_data")
}
#' @rdname PredictionData
#' @export
is_missing_prediction_data = function(pdata, ...) {
  # S3 generic: dispatches on class(pdata). Per the PredictionData description,
  # methods return the row ids whose predictions are missing and need to be
  # imputed by the fallback learner.
  UseMethod("is_missing_prediction_data")
}
#' @rdname PredictionData
#' @template param_row_ids
#' @export
filter_prediction_data = function(pdata, row_ids, ...) {
  # S3 generic: dispatches on class(pdata).
  # NOTE(review): presumably methods subset `pdata` to the entries belonging to
  # `row_ids`; the methods are not visible here, so confirm against them.
  UseMethod("filter_prediction_data")
}