-
Notifications
You must be signed in to change notification settings - Fork 318
/
Copy pathloader.go
288 lines (262 loc) · 8.07 KB
/
loader.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
// Package loader loads query and schema information from mysql, oracle,
// postgres, sqlite3 and sqlserver databases.
package loader
import (
"context"
"database/sql"
"fmt"
"regexp"
"sort"
"strings"
"github.com/kenshaw/snaker"
"github.com/xo/xo/models"
xo "github.com/xo/xo/types"
)
// loaders holds every registered database loader, keyed by the driver type
// name it was registered under (the same name get looks up from the
// context's xo.DriverKey value).
var loaders = map[string]Loader{}

// Register registers a database loader under the driver type name typ.
//
// Registering another loader under an already-used typ replaces the earlier
// one.
func Register(typ string, loader Loader) {
	loaders[typ] = loader
}
// Flags returns the additional flags for the loaders.
//
// Every registered loader that provides a Flags func contributes its flags,
// each tagged with the loader's registered type name and the flag's context
// key. Loaders are visited in sorted type-name order so the result does not
// depend on map iteration order; loaders without a Flags func are skipped.
//
// These should be added to the invocation context for any call to a loader
// func.
func Flags() []xo.FlagSet {
	names := make([]string, 0, len(loaders))
	for name := range loaders {
		names = append(names, name)
	}
	sort.Strings(names)
	var sets []xo.FlagSet
	for _, name := range names {
		if flagsFunc := loaders[name].Flags; flagsFunc != nil {
			for _, f := range flagsFunc() {
				sets = append(sets, xo.FlagSet{
					Type: name,
					Name: string(f.ContextKey),
					Flag: f,
				})
			}
		}
	}
	return sets
}
// Loader loads type information from a database.
//
// Most func fields are driver-specific hooks called by the package-level
// function of the same name. Those functions resolve the loader, the *sql.DB
// and the active schema name from the context before calling the hook.
// The package-level Flags, Enums, Procs, ProcParams, ViewSchema, ViewTruncate
// and ViewStrip functions check their hook for nil and fall back to a default
// result, so those hooks are optional. Every other hook is called
// unconditionally, and a nil value panics when its package-level function
// is used.
type Loader struct {
// Type is the loader's type name. NOTE(review): Register does not set it and
// nothing in this file reads it — confirm how it is used elsewhere.
Type string
// Mask is the query placeholder format used by NthParam. If it contains
// "%d", the 1-based parameter index is substituted. If it is empty, "?" is
// used.
Mask string
// Flags returns the loader's additional flags (collected by the package-level
// Flags func; a nil Flags is skipped there).
Flags func() []xo.Flag
// Schema returns the active schema name. Args: ctx, db.
Schema func(context.Context, models.DB) (string, error)
// Enums returns a schema's enums. Args: ctx, db, schema.
Enums func(context.Context, models.DB, string) ([]*models.Enum, error)
// EnumValues returns an enum's values. Args: ctx, db, schema, enum name.
EnumValues func(context.Context, models.DB, string, string) ([]*models.EnumValue, error)
// Procs returns a schema's procs. Args: ctx, db, schema.
Procs func(context.Context, models.DB, string) ([]*models.Proc, error)
// ProcParams returns a proc's params. Args: ctx, db, schema, proc id.
ProcParams func(context.Context, models.DB, string, string) ([]*models.ProcParam, error)
// Tables returns a schema's tables. Args: ctx, db, schema, typ. The typ
// value is passed through from the package-level Tables.
Tables func(context.Context, models.DB, string, string) ([]*models.Table, error)
// TableColumns returns a table's columns. Args: ctx, db, schema, table.
TableColumns func(context.Context, models.DB, string, string) ([]*models.Column, error)
// TableSequences returns a table's sequences. Args: ctx, db, schema, table.
TableSequences func(context.Context, models.DB, string, string) ([]*models.Sequence, error)
// TableForeignKeys returns a table's foreign keys. Args: ctx, db, schema,
// table.
TableForeignKeys func(context.Context, models.DB, string, string) ([]*models.ForeignKey, error)
// TableIndexes returns a table's indexes. Args: ctx, db, schema, table.
TableIndexes func(context.Context, models.DB, string, string) ([]*models.Index, error)
// IndexColumns returns an index's columns. Args: ctx, db, schema, table,
// index.
IndexColumns func(context.Context, models.DB, string, string, string) ([]*models.IndexColumn, error)
// ViewCreate creates an introspection view of a query. Args: ctx, db,
// schema, view id, query lines.
ViewCreate func(context.Context, models.DB, string, string, []string) (sql.Result, error)
// ViewSchema returns the schema an introspection view was created in.
// Args: ctx, db, view id. The active schema is not passed.
ViewSchema func(context.Context, models.DB, string) (string, error)
// ViewTruncate truncates an introspection view. Args: ctx, db, schema, view
// id.
ViewTruncate func(context.Context, models.DB, string, string) (sql.Result, error)
// ViewDrop drops an introspection view. Args: ctx, db, schema, view id.
ViewDrop func(context.Context, models.DB, string, string) (sql.Result, error)
// ViewStrip post-processes the query and inspected query lines. It returns
// the altered query, the altered inspect lines, and a set of comments for
// the query.
ViewStrip func([]string, []string) ([]string, []string, []string, error)
}
// get retrieves the database connection, loader, and schema name from the
// context.
//
// The loader is chosen by the driver name stored under xo.DriverKey. An
// error is returned only when no loader is registered for that name. A
// missing or differently-typed xo.DbKey / xo.SchemaKey value yields a nil
// *sql.DB / empty schema name rather than an error. The returned *Loader
// points at a copy of the registered value.
func get(ctx context.Context) (*sql.DB, *Loader, string, error) {
	driver, _ := ctx.Value(xo.DriverKey).(string)
	ld, found := loaders[driver]
	if !found {
		return nil, nil, "", fmt.Errorf("no database loader available for %q", driver)
	}
	conn, _ := ctx.Value(xo.DbKey).(*sql.DB)
	schemaName, _ := ctx.Value(xo.SchemaKey).(string)
	return conn, &ld, schemaName, nil
}
// NthParam returns a 0-based func to generate the nth param placeholder for
// database queries.
//
// The placeholder comes from the loader's Mask, defaulting to "?". A mask
// containing "%d" is formatted with the 1-based position (i+1). Any other
// mask is returned unchanged for every position.
func NthParam(ctx context.Context) (func(int) string, error) {
	_, ld, _, err := get(ctx)
	if err != nil {
		return nil, err
	}
	placeholder := ld.Mask
	if placeholder == "" {
		placeholder = "?"
	}
	if strings.Contains(placeholder, "%d") {
		return func(i int) string {
			return fmt.Sprintf(placeholder, i+1)
		}, nil
	}
	// positionless placeholder: every parameter uses the same text
	return func(int) string {
		return placeholder
	}, nil
}
// Schema loads the active schema name by calling the context loader's Schema
// hook with the context's database connection.
func Schema(ctx context.Context) (string, error) {
	conn, ld, _, err := get(ctx)
	if err != nil {
		return "", err
	}
	return ld.Schema(ctx, conn)
}
// Enums returns the database enums of the active schema.
//
// When the loader provides no Enums hook, it returns (nil, nil).
func Enums(ctx context.Context) ([]*models.Enum, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	if ld.Enums == nil {
		return nil, nil
	}
	return ld.Enums(ctx, conn, schema)
}
// EnumValues returns the values of the database enum named enum in the
// active schema.
//
// Like Enums, a loader that provides no EnumValues hook yields (nil, nil).
// The wrapper returns no values instead of panicking on a nil func call.
func EnumValues(ctx context.Context, enum string) ([]*models.EnumValue, error) {
	db, l, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	if l.EnumValues == nil {
		return nil, nil
	}
	return l.EnumValues(ctx, db, schema, enum)
}
// Procs returns the database procs of the active schema.
//
// When the loader provides no Procs hook, it returns (nil, nil).
func Procs(ctx context.Context) ([]*models.Proc, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	if ld.Procs == nil {
		return nil, nil
	}
	return ld.Procs(ctx, conn, schema)
}
// ProcParams returns the params of the database proc identified by id in the
// active schema.
//
// When the loader provides no ProcParams hook, it returns (nil, nil).
func ProcParams(ctx context.Context, id string) ([]*models.ProcParam, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	if ld.ProcParams == nil {
		return nil, nil
	}
	return ld.ProcParams(ctx, conn, schema, id)
}
// Tables returns the database tables of the active schema.
//
// typ is forwarded unchanged to the loader's Tables hook.
// NOTE(review): presumably it selects the relation kind; confirm with callers.
func Tables(ctx context.Context, typ string) ([]*models.Table, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.Tables(ctx, conn, schema, typ)
}
// TableColumns returns the columns of the database table named table in the
// active schema.
func TableColumns(ctx context.Context, table string) ([]*models.Column, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.TableColumns(ctx, conn, schema, table)
}
// TableSequences returns the sequences of the database table named table in
// the active schema.
func TableSequences(ctx context.Context, table string) ([]*models.Sequence, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.TableSequences(ctx, conn, schema, table)
}
// TableForeignKeys returns the foreign keys of the database table named table
// in the active schema.
func TableForeignKeys(ctx context.Context, table string) ([]*models.ForeignKey, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.TableForeignKeys(ctx, conn, schema, table)
}
// TableIndexes returns the indexes of the database table named table in the
// active schema.
func TableIndexes(ctx context.Context, table string) ([]*models.Index, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.TableIndexes(ctx, conn, schema, table)
}
// IndexColumns returns the columns of the index named index on the database
// table named table in the active schema.
func IndexColumns(ctx context.Context, table, index string) ([]*models.IndexColumn, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.IndexColumns(ctx, conn, schema, table, index)
}
// ViewCreate creates an introspection view of a query.
//
// The view id and the query lines are passed, together with the active
// schema, to the loader's ViewCreate hook.
func ViewCreate(ctx context.Context, id string, query []string) (sql.Result, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.ViewCreate(ctx, conn, schema, id, query)
}
// ViewSchema returns the schema that the introspection view id was created in.
//
// Unlike most wrappers, the active schema from the context is not passed to
// the hook. When the loader provides no ViewSchema hook, it returns ("", nil).
func ViewSchema(ctx context.Context, id string) (string, error) {
	conn, ld, _, err := get(ctx)
	if err != nil {
		return "", err
	}
	if ld.ViewSchema == nil {
		return "", nil
	}
	return ld.ViewSchema(ctx, conn, id)
}
// ViewTruncate truncates the introspection view id in the active schema.
//
// When the loader provides no ViewTruncate hook, nothing is executed and it
// returns (nil, nil).
func ViewTruncate(ctx context.Context, id string) (sql.Result, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	if ld.ViewTruncate == nil {
		return nil, nil
	}
	return ld.ViewTruncate(ctx, conn, schema, id)
}
// ViewDrop drops the introspection view id from the active schema by calling
// the loader's ViewDrop hook.
func ViewDrop(ctx context.Context, id string) (sql.Result, error) {
	conn, ld, schema, err := get(ctx)
	if err != nil {
		return nil, err
	}
	return ld.ViewDrop(ctx, conn, schema, id)
}
// ViewStrip post processes the query and inspected query, altering as
// necessary and building a set of comments for the query.
//
// When the loader provides no ViewStrip hook, query and inspect are returned
// unchanged, along with one empty comment per query line.
func ViewStrip(ctx context.Context, query, inspect []string) ([]string, []string, []string, error) {
	_, ld, _, err := get(ctx)
	if err != nil {
		return nil, nil, nil, err
	}
	if ld.ViewStrip == nil {
		comments := make([]string, len(query))
		return query, inspect, comments, nil
	}
	return ld.ViewStrip(query, inspect)
}
// schemaType returns the Go type and zero value for the database type typ.
//
// A leading "<schema>." qualifier matching schema is removed, because such a
// type is considered to be in the same package. Nullable types get a "null_"
// prefix. The resulting name is passed through
// snaker.SnakeToCamelIdentifier. The zero value is that identifier followed
// by "{}".
func schemaType(typ string, nullable bool, schema string) (string, string) {
	name := strings.TrimPrefix(typ, schema+".")
	if nullable {
		name = "null_" + name
	}
	goType := snaker.SnakeToCamelIdentifier(name)
	return goType, goType + "{}"
}
// intRE matches exactly the Go signed integer type names: int, int8, int16,
// int32 and int64.
var (
	intRE = regexp.MustCompile(`^int(8|16|32|64)?$`)
)