Skip to content

Commit

Permalink
Add Conn.SetAuthorizer method
Browse files Browse the repository at this point in the history
  • Loading branch information
zombiezen authored and AdamSLevy committed May 24, 2020
1 parent e49f25b commit 670c3fa
Show file tree
Hide file tree
Showing 4 changed files with 295 additions and 10 deletions.
132 changes: 132 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package sqlite

// #include <stdint.h>
// #include <sqlite3.h>
// extern int go_sqlite_auth_tramp(uintptr_t, int, char*, char*, char*, char*);
// static int c_auth_tramp(void *userData, int action, const char* arg1, const char* arg2, const char* db, const char* trigger) {
// return go_sqlite_auth_tramp((uintptr_t)userData, action, (char*)arg1, (char*)arg2, (char*)db, (char*)trigger);
// }
// static int sqlite3_go_set_authorizer(sqlite3* conn, uintptr_t id) {
// return sqlite3_set_authorizer(conn, c_auth_tramp, (void*)id);
// }
import "C"
import (
"errors"
"sync"
)

// An Authorizer is called during statement preparation to see whether an action
// is allowed by the application. See https://sqlite.org/c3ref/set_authorizer.html
type Authorizer interface {
Authorize(action OpType, info ActionInfo) AuthResult
}

// ActionInfo holds information about an action to be authorized.
type ActionInfo struct {
Arg1 string
Arg2 string
Database string
Trigger string
}

// SetAuthorizer registers an authorizer for the database connection.
// SetAuthorizer(nil) clears any authorizer previously set.
func (conn *Conn) SetAuthorizer(auth Authorizer) error {
if auth == nil {
if conn.authorizer == -1 {
return nil
}
conn.releaseAuthorizer()
res := C.sqlite3_set_authorizer(conn.conn, nil, nil)
return reserr("SetAuthorizer", "", "", res)
}

authFuncs.mu.Lock()
id := authFuncs.next
next := authFuncs.next + 1
if next < 0 {
authFuncs.mu.Unlock()
return errors.New("sqlite: authorizer function id overflow")
}
authFuncs.next = next
authFuncs.m[id] = auth
authFuncs.mu.Unlock()

res := C.sqlite3_go_set_authorizer(conn.conn, C.uintptr_t(id))
return reserr("SetAuthorizer", "", "", res)
}

func (conn *Conn) releaseAuthorizer() {
if conn.authorizer == -1 {
return
}
authFuncs.mu.Lock()
delete(authFuncs.m, conn.authorizer)
authFuncs.mu.Unlock()
conn.authorizer = -1
}

var authFuncs = struct {
mu sync.RWMutex
m map[int]Authorizer
next int
}{
m: make(map[int]Authorizer),
}

//export go_sqlite_auth_tramp
func go_sqlite_auth_tramp(id uintptr, action C.int, arg1, arg2 *C.char, db *C.char, trigger *C.char) C.int {
authFuncs.mu.RLock()
auth := authFuncs.m[int(id)]
authFuncs.mu.RUnlock()
info := ActionInfo{}
if arg1 != nil {
info.Arg1 = C.GoString(arg1)
}
if arg2 != nil {
info.Arg2 = C.GoString(arg2)
}
if db != nil {
info.Database = C.GoString(db)
}
if trigger != nil {
info.Trigger = C.GoString(trigger)
}
return C.int(auth.Authorize(OpType(action), info))
}

// AuthorizeFunc is a function that implements Authorizer.
type AuthorizeFunc func(action OpType, info ActionInfo) AuthResult

// Authorize calls f.
func (f AuthorizeFunc) Authorize(action OpType, info ActionInfo) AuthResult {
return f(action, info)
}

// AuthResult is the result of a call to an Authorizer. The zero value is
// SQLITE_OK.
type AuthResult int

// Possible return values of an Authorizer.
const (
// Cause the entire SQL statement to be rejected with an error.
SQLITE_DENY = AuthResult(C.SQLITE_DENY)
// Disallow the specific action but allow the SQL statement to continue to
// be compiled.
SQLITE_IGNORE = AuthResult(C.SQLITE_IGNORE)
)

// String returns the C constant name of the result.
func (result AuthResult) String() string {
switch result {
default:
var buf [20]byte
return "SQLITE_UNKNOWN_AUTH_RESULT(" + string(itoa(buf[:], int64(result))) + ")"
case AuthResult(C.SQLITE_OK):
return "SQLITE_OK"
case SQLITE_DENY:
return "SQLITE_DENY"
case SQLITE_IGNORE:
return "SQLITE_IGNORE"
}
}
54 changes: 54 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sqlite_test

import (
"testing"

"crawshaw.io/sqlite"
)

func TestSetAuthorizer(t *testing.T) {
c, err := sqlite.OpenConn(":memory:", 0)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := c.Close(); err != nil {
t.Error(err)
}
}()

authResult := sqlite.AuthResult(0)
var lastAction sqlite.OpType
auth := sqlite.AuthorizeFunc(func(action sqlite.OpType, info sqlite.ActionInfo) sqlite.AuthResult {
lastAction = action
return authResult
})
c.SetAuthorizer(auth)

t.Run("Allowed", func(t *testing.T) {
authResult = 0
stmt, _, err := c.PrepareTransient("SELECT 1;")
if err != nil {
t.Fatal(err)
}
stmt.Finalize()
if lastAction != sqlite.SQLITE_SELECT {
t.Errorf("action = %q; want SQLITE_SELECT", lastAction)
}
})

t.Run("Denied", func(t *testing.T) {
authResult = sqlite.SQLITE_DENY
stmt, _, err := c.PrepareTransient("SELECT 1;")
if err == nil {
stmt.Finalize()
t.Fatal("PrepareTransient did not return an error")
}
if got, want := sqlite.ErrCode(err), sqlite.SQLITE_AUTH; got != want {
t.Errorf("sqlite.ErrCode(err) = %v; want %v", got, want)
}
if lastAction != sqlite.SQLITE_SELECT {
t.Errorf("action = %q; want SQLITE_SELECT", lastAction)
}
})
}
106 changes: 101 additions & 5 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -406,25 +406,121 @@ func (iter ChangesetIter) PK() ([]bool, error) {
return cols, nil
}

// OpType is an enumeration of SQLite statements. Used for authorization and
// changeset details.
type OpType int

// Operation types
const (
SQLITE_INSERT = OpType(C.SQLITE_INSERT)
SQLITE_DELETE = OpType(C.SQLITE_DELETE)
SQLITE_UPDATE = OpType(C.SQLITE_UPDATE)
SQLITE_CREATE_INDEX = OpType(C.SQLITE_CREATE_INDEX)
SQLITE_CREATE_TABLE = OpType(C.SQLITE_CREATE_TABLE)
SQLITE_CREATE_TEMP_INDEX = OpType(C.SQLITE_CREATE_TEMP_INDEX)
SQLITE_CREATE_TEMP_TABLE = OpType(C.SQLITE_CREATE_TEMP_TABLE)
SQLITE_CREATE_TEMP_TRIGGER = OpType(C.SQLITE_CREATE_TEMP_TRIGGER)
SQLITE_CREATE_TEMP_VIEW = OpType(C.SQLITE_CREATE_TEMP_VIEW)
SQLITE_CREATE_TRIGGER = OpType(C.SQLITE_CREATE_TRIGGER)
SQLITE_CREATE_VIEW = OpType(C.SQLITE_CREATE_VIEW)
SQLITE_DELETE = OpType(C.SQLITE_DELETE)
SQLITE_DROP_INDEX = OpType(C.SQLITE_DROP_INDEX)
SQLITE_DROP_TABLE = OpType(C.SQLITE_DROP_TABLE)
SQLITE_DROP_TEMP_INDEX = OpType(C.SQLITE_DROP_TEMP_INDEX)
SQLITE_DROP_TEMP_TABLE = OpType(C.SQLITE_DROP_TEMP_TABLE)
SQLITE_DROP_TEMP_TRIGGER = OpType(C.SQLITE_DROP_TEMP_TRIGGER)
SQLITE_DROP_TEMP_VIEW = OpType(C.SQLITE_DROP_TEMP_VIEW)
SQLITE_DROP_TRIGGER = OpType(C.SQLITE_DROP_TRIGGER)
SQLITE_DROP_VIEW = OpType(C.SQLITE_DROP_VIEW)
SQLITE_INSERT = OpType(C.SQLITE_INSERT)
SQLITE_PRAGMA = OpType(C.SQLITE_PRAGMA)
SQLITE_READ = OpType(C.SQLITE_READ)
SQLITE_SELECT = OpType(C.SQLITE_SELECT)
SQLITE_TRANSACTION = OpType(C.SQLITE_TRANSACTION)
SQLITE_UPDATE = OpType(C.SQLITE_UPDATE)
SQLITE_ATTACH = OpType(C.SQLITE_ATTACH)
SQLITE_DETACH = OpType(C.SQLITE_DETACH)
SQLITE_ALTER_TABLE = OpType(C.SQLITE_ALTER_TABLE)
SQLITE_REINDEX = OpType(C.SQLITE_REINDEX)
SQLITE_ANALYZE = OpType(C.SQLITE_ANALYZE)
SQLITE_CREATE_VTABLE = OpType(C.SQLITE_CREATE_VTABLE)
SQLITE_DROP_VTABLE = OpType(C.SQLITE_DROP_VTABLE)
SQLITE_FUNCTION = OpType(C.SQLITE_FUNCTION)
SQLITE_SAVEPOINT = OpType(C.SQLITE_SAVEPOINT)
SQLITE_COPY = OpType(C.SQLITE_COPY)
SQLITE_RECURSIVE = OpType(C.SQLITE_RECURSIVE)
)

func (opType OpType) String() string {
switch opType {
default:
var buf [20]byte
return "SQLITE_UNKNOWN_OP_TYPE(" + string(itoa(buf[:], int64(opType))) + ")"
case SQLITE_INSERT:
return "SQLITE_INSERT"
case SQLITE_CREATE_INDEX:
return "SQLITE_CREATE_INDEX"
case SQLITE_CREATE_TABLE:
return "SQLITE_CREATE_TABLE"
case SQLITE_CREATE_TEMP_INDEX:
return "SQLITE_CREATE_TEMP_INDEX"
case SQLITE_CREATE_TEMP_TABLE:
return "SQLITE_CREATE_TEMP_TABLE"
case SQLITE_CREATE_TEMP_TRIGGER:
return "SQLITE_CREATE_TEMP_TRIGGER"
case SQLITE_CREATE_TEMP_VIEW:
return "SQLITE_CREATE_TEMP_VIEW"
case SQLITE_CREATE_TRIGGER:
return "SQLITE_CREATE_TRIGGER"
case SQLITE_CREATE_VIEW:
return "SQLITE_CREATE_VIEW"
case SQLITE_DELETE:
return "SQLITE_DELETE"
case SQLITE_DROP_INDEX:
return "SQLITE_DROP_INDEX"
case SQLITE_DROP_TABLE:
return "SQLITE_DROP_TABLE"
case SQLITE_DROP_TEMP_INDEX:
return "SQLITE_DROP_TEMP_INDEX"
case SQLITE_DROP_TEMP_TABLE:
return "SQLITE_DROP_TEMP_TABLE"
case SQLITE_DROP_TEMP_TRIGGER:
return "SQLITE_DROP_TEMP_TRIGGER"
case SQLITE_DROP_TEMP_VIEW:
return "SQLITE_DROP_TEMP_VIEW"
case SQLITE_DROP_TRIGGER:
return "SQLITE_DROP_TRIGGER"
case SQLITE_DROP_VIEW:
return "SQLITE_DROP_VIEW"
case SQLITE_INSERT:
return "SQLITE_INSERT"
case SQLITE_PRAGMA:
return "SQLITE_PRAGMA"
case SQLITE_READ:
return "SQLITE_READ"
case SQLITE_SELECT:
return "SQLITE_SELECT"
case SQLITE_TRANSACTION:
return "SQLITE_TRANSACTION"
case SQLITE_UPDATE:
return "SQLITE_UPDATE"
case SQLITE_ATTACH:
return "SQLITE_ATTACH"
case SQLITE_DETACH:
return "SQLITE_DETACH"
case SQLITE_ALTER_TABLE:
return "SQLITE_ALTER_TABLE"
case SQLITE_REINDEX:
return "SQLITE_REINDEX"
case SQLITE_ANALYZE:
return "SQLITE_ANALYZE"
case SQLITE_CREATE_VTABLE:
return "SQLITE_CREATE_VTABLE"
case SQLITE_DROP_VTABLE:
return "SQLITE_DROP_VTABLE"
case SQLITE_FUNCTION:
return "SQLITE_FUNCTION"
case SQLITE_SAVEPOINT:
return "SQLITE_SAVEPOINT"
case SQLITE_COPY:
return "SQLITE_COPY"
case SQLITE_RECURSIVE:
return "SQLITE_RECURSIVE"
}
}

Expand Down
13 changes: 8 additions & 5 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,11 @@ import (
//
// A Conn can only be used by goroutine at a time.
type Conn struct {
conn *C.sqlite3
stmts map[string]*Stmt // query -> prepared statement
closed bool
count int // shared variable to help the race detector find Conn misuse
conn *C.sqlite3
stmts map[string]*Stmt // query -> prepared statement
authorizer int // authorizer ID or -1
closed bool
count int // shared variable to help the race detector find Conn misuse

cancelCh chan struct{}
tracer Tracer
Expand Down Expand Up @@ -135,7 +136,8 @@ func openConn(path string, flags OpenFlags) (*Conn, error) {
flags = SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_WAL | SQLITE_OPEN_URI | SQLITE_OPEN_NOMUTEX
}
conn := &Conn{
stmts: make(map[string]*Stmt),
stmts: make(map[string]*Stmt),
authorizer: -1,
// A pointer to unlockNote is retained by C,
// so we allocate it on the C heap.
unlockNote: C.unlock_note_alloc(),
Expand Down Expand Up @@ -201,6 +203,7 @@ func (conn *Conn) Close() error {
res := C.sqlite3_close(conn.conn)
C.unlock_note_free(conn.unlockNote)
conn.unlockNote = nil
conn.releaseAuthorizer()
return reserr("Conn.Close", "", "", res)
}

Expand Down

0 comments on commit 670c3fa

Please sign in to comment.