Skip to content

Commit

Permalink
Overhaul of file: and opaque path handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kenshaw committed Nov 13, 2023
1 parent e00fe6c commit 610d3ae
Show file tree
Hide file tree
Showing 5 changed files with 1,018 additions and 286 deletions.
194 changes: 153 additions & 41 deletions dburl.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,12 @@ package dburl

import (
"database/sql"
"io/fs"
"net/url"
"os"
"path"
"path/filepath"
"runtime"
"strings"
)

Expand Down Expand Up @@ -78,6 +82,9 @@ func Parse(urlstr string) (*URL, error) {
case err != nil:
return nil, err
case v.Scheme == "":
if typ, err := SchemeType(urlstr); err == nil {
return Parse(typ + ":" + urlstr)
}
return nil, ErrInvalidDatabaseScheme
}
// create url
Expand All @@ -95,44 +102,31 @@ func Parse(urlstr string) (*URL, error) {
}
// get dsn generator
scheme, ok := schemeMap[u.Scheme]
if !ok {
switch {
case !ok:
return nil, ErrUnknownDatabaseScheme
}
// load real scheme for file:
if scheme.Driver == "file" {
typ, err := SchemeType(u.Opaque)
if err == nil {
if s, ok := schemeMap[typ]; ok {
scheme = s
}
}
}
// if scheme does not understand opaque URLs, retry parsing after building
// fully qualified URL
if !scheme.Opaque && u.Opaque != "" {
var q string
if u.RawQuery != "" {
q = "?" + u.RawQuery
}
var f string
if u.Fragment != "" {
f = "#" + u.Fragment
case scheme.Driver == "file" && u.Opaque != "":
// determine scheme for file
if typ, err := SchemeType(u.Opaque); err == nil {
return Parse(typ + ":" + buildOpaque(u))
}
return Parse(u.OriginalScheme + "://" + u.Opaque + q + f)
}
if scheme.Opaque && u.Opaque == "" {
return nil, ErrUnknownFileExtension
case !scheme.Opaque && u.Opaque != "":
// if scheme does not understand opaque URLs, retry parsing after
// building fully qualified URL
return Parse(u.OriginalScheme + "://" + buildOpaque(u))
case scheme.Opaque && u.Opaque == "":
// force Opaque
u.Opaque, u.Host, u.Path, u.RawPath = u.Host+u.Path, "", "", ""
} else if u.Host == "." || (u.Host == "" && strings.TrimPrefix(u.Path, "/") != "") {
case u.Host == ".", u.Host == "" && strings.TrimPrefix(u.Path, "/") != "":
// force unix proto
u.Transport = "unix"
}
// check proto
// check transport
if checkTransport || u.Transport != "tcp" {
if scheme.Transport == TransportNone {
return nil, ErrInvalidTransportProtocol
}
switch {
case scheme.Transport == TransportNone:
return nil, ErrInvalidTransportProtocol
case scheme.Transport&TransportAny != 0 && u.Transport != "",
scheme.Transport&TransportTCP != 0 && u.Transport == "tcp",
scheme.Transport&TransportUDP != 0 && u.Transport == "udp",
Expand Down Expand Up @@ -240,23 +234,36 @@ func (u *URL) Normalize(sep, empty string, cut int) string {
return strings.Join(s, sep)
}

// SchemeType returns the scheme type for a file on disk.
// SchemeType returns the scheme type for a path.
func SchemeType(name string) (string, error) {
f, err := os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
return "", err
// try to resolve the path on unix systems
if runtime.GOOS != "windows" /*&& !mode(name).IsRegular()*/ {
if typ, ok := resolveType(name); ok {
return typ, nil
}
}
defer f.Close()
buf := make([]byte, 128)
if _, err := f.Read(buf); err != nil {
return "", err
if f, err := OpenFile(name); err == nil {
defer f.Close()
// file exists, match header
buf := make([]byte, 64)
if n, _ := f.Read(buf); n == 0 {
return "sqlite3", nil
}
for _, typ := range fileTypes {
if typ.f(buf) {
return typ.driver, nil
}
}
return "", ErrUnknownFileHeader
}
for _, header := range headerTypes {
if header.f(buf) {
return header.driver, nil
// doesn't exist, match file extension
ext := filepath.Ext(name)
for _, typ := range fileTypes {
if typ.ext.MatchString(ext) {
return typ.driver, nil
}
}
return "", ErrUnknownFileHeader
return "", ErrUnknownFileExtension
}

// Error is an error.
Expand All @@ -275,6 +282,8 @@ const (
ErrUnknownDatabaseScheme Error = "unknown database scheme"
// ErrUnknownFileHeader is the unknown file header error.
ErrUnknownFileHeader Error = "unknown file header"
// ErrUnknownFileExtension is the unknown file extension error.
ErrUnknownFileExtension Error = "unknown file extension"
// ErrInvalidTransportProtocol is the invalid transport protocol error.
ErrInvalidTransportProtocol Error = "invalid transport protocol"
// ErrRelativePathNotSupported is the relative paths not supported error.
Expand All @@ -286,3 +295,106 @@ const (
// ErrMissingUser is the missing user error.
ErrMissingUser Error = "missing user"
)

// Stat is the default stat func.
//
// Used internally to stat files, and used when generating the DSNs for
// postgres://, mysql://, file:// schemes, and opaque [URL]'s.
var Stat = func(name string) (fs.FileInfo, error) {
return fs.Stat(os.DirFS(filepath.Dir(name)), filepath.Base(name))
}

// OpenFile is the default open file func.
//
// Used internally to read file headers.
var OpenFile = func(name string) (fs.File, error) {
f, err := os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
return nil, err
}
return f, nil
}

// buildOpaque builds a opaque path from u.
func buildOpaque(u *URL) string {
var q string
if u.RawQuery != "" {
q = "?" + u.RawQuery
}
var f string
if u.Fragment != "" {
f = "#" + u.Fragment
}
return u.Opaque + q + f
}

// resolveType tries to resolve a path to a Unix domain socket or directory.
func resolveType(s string) (string, bool) {
dir := s
for dir != "" && dir != "/" && dir != "." {
// chop off :4444 port
i, j := strings.LastIndex(dir, ":"), strings.LastIndex(dir, "/")
if i != -1 && i > j {
dir = dir[:i]
}
switch fi, err := Stat(dir); {
case err == nil && fi.IsDir():
return "postgres", true
case err == nil && fi.Mode()&fs.ModeSocket != 0:
return "mysql", true
case err == nil:
return "", false
}
if j != -1 {
dir = dir[:j]
} else {
dir = ""
}
}
return "", false
}

// resolveSocket tries to resolve a path to a Unix domain socket based on the
// form "/path/to/socket/dbname" returning either the original path and the
// empty string, or the components "/path/to/socket" and "dbname", when
// /path/to/socket/dbname is reported by Stat as a socket.
func resolveSocket(s string) (string, string) {
dir, dbname := s, ""
for dir != "" && dir != "/" && dir != "." {
if mode(dir)&fs.ModeSocket != 0 {
return dir, dbname
}
dir, dbname = path.Dir(dir), path.Base(dir)
}
return s, ""
}

// resolveDir resolves a directory with a :port list.
func resolveDir(s string) (string, string, string) {
dir := s
for dir != "" && dir != "/" && dir != "." {
port := ""
i, j := strings.LastIndex(dir, ":"), strings.LastIndex(dir, "/")
if i != -1 && i > j {
port, dir = dir[i+1:], dir[:i]
}
if mode(dir)&fs.ModeDir != 0 {
dbname := strings.TrimPrefix(strings.TrimPrefix(strings.TrimPrefix(s, dir), ":"+port), "/")
return dir, port, dbname
}
if j != -1 {
dir = dir[:j]
} else {
dir = ""
}
}
return s, "", ""
}

// mode returns the mode of the path.
func mode(s string) os.FileMode {
if fi, err := Stat(s); err == nil {
return fi.Mode()
}
return 0
}
Loading

0 comments on commit 610d3ae

Please sign in to comment.