Skip to content

Commit

Permalink
Improve sql support
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 26, 2023
1 parent dc078bd commit fa001ff
Show file tree
Hide file tree
Showing 11 changed files with 306 additions and 30 deletions.
31 changes: 22 additions & 9 deletions chain/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ var _ schema.Chain = (*SQL)(nil)

type SQLOptions struct {
*schema.CallbackOptions
InputKey string
TablesInputKey string
OutputKey string
TopK uint
InputKey string
TablesInputKey string
OutputKey string
TopK uint
Schema string
Tables []string
Exclude []string
SampleRowsinTableInfo uint
}

type SQL struct {
Expand All @@ -46,17 +50,26 @@ type SQL struct {
opts SQLOptions
}

func NewSQL(llm schema.Model, engine sqldb.Engine) (*SQL, error) {
func NewSQL(llm schema.Model, engine sqldb.Engine, optFns ...func(o *SQLOptions)) (*SQL, error) {
opts := SQLOptions{
InputKey: "query",
OutputKey: "result",
TopK: 5,
InputKey: "query",
OutputKey: "result",
TopK: 5,
SampleRowsinTableInfo: 3,
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
}

sqldb, err := sqldb.New(engine)
for _, fn := range optFns {
fn(&opts)
}

sqldb, err := sqldb.New(engine, func(o *sqldb.SQLDBOptions) {
o.Tables = opts.Tables
o.Exclude = opts.Exclude
o.SampleRowsinTableInfo = opts.SampleRowsinTableInfo
})
if err != nil {
return nil, err
}
Expand Down
11 changes: 10 additions & 1 deletion docs/content/en/docs/chains/sql.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ title: SQL
description: All about sql chains.
weight: 70
---

```go
package main

Expand Down Expand Up @@ -34,6 +35,8 @@ func main() {
log.Fatal(err)
}

defer engine.Close()

// Only for demonstration
_, exErr := engine.Exec(ctx, "CREATE TABLE IF NOT EXISTS employee ( id int not null );")
if exErr != nil {
Expand Down Expand Up @@ -64,4 +67,10 @@ func main() {
Output:
```text
There are 4 employees.
```
```

## Supported databases
MySQL, MariaDB, PostgresSQL, SQLite, CockroachDB

## Golang SQL Drivers
https://github.com/golang/go/wiki/SQLDrivers
9 changes: 6 additions & 3 deletions examples/sql/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/hupe1980/golc/integration/sqldb"
"github.com/hupe1980/golc/model/llm"

// Add your sql db driver, see https://github.com/golang/go/wiki/SQLDrivers
_ "github.com/mattn/go-sqlite3"
)

Expand All @@ -27,15 +28,17 @@ func main() {
log.Fatal(err)
}

defer engine.Close()

_, exErr := engine.Exec(ctx, "CREATE TABLE IF NOT EXISTS employee ( id int not null );")
if exErr != nil {
log.Fatal(exErr)
}

for i := 0; i < 4; i++ {
_, qErr := engine.Exec(ctx, "INSERT INTO employee (id) VALUES (?) ;", i)
if qErr != nil {
log.Fatal(qErr)
_, iErr := engine.Exec(ctx, "INSERT INTO employee (id) VALUES (?);", i)
if iErr != nil {
log.Fatal(iErr)
}
}

Expand Down
41 changes: 41 additions & 0 deletions integration/sqldb/cockroachdb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqldb

// Compile time check to ensure CockroachDB satisfies the Engine interface.
var _ Engine = (*CockroachDB)(nil)

// CockroachDBOptions holds options for the CockroachDB database engine.
type CockroachDBOptions struct {
DriverName string
}

// CockroachDB represents the CockroachDB database engine.
type CockroachDB struct {
*Postgres
}

// NewCockroachDB creates a new instance of the CockroachDB database engine.
func NewCockroachDB(dataSourceName string, optFns ...func(o *CockroachDBOptions)) (*CockroachDB, error) {
opts := CockroachDBOptions{
DriverName: "pgx",
}

for _, fn := range optFns {
fn(&opts)
}

postgres, err := NewPostgres(dataSourceName, func(o *PostgresOptions) {
o.DriverName = opts.DriverName
})
if err != nil {
return nil, err
}

return &CockroachDB{
Postgres: postgres,
}, nil
}

// Dialect returns the dialect of the CockroachDB database engine.
func (e *CockroachDB) Dialect() string {
return "CockroachDB"
}
41 changes: 41 additions & 0 deletions integration/sqldb/mariadb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package sqldb

// Compile time check to ensure MariaDB satisfies the Engine interface.
var _ Engine = (*MariaDB)(nil)

// MariaDBOptions holds options for the MariaDB database engine.
type MariaDBOptions struct {
DriverName string
}

// MariaDB represents the MariaDB database engine.
type MariaDB struct {
*MySQL
}

// NewMariaDB creates a new instance of the MariaDB database engine.
func NewMariaDB(dataSourceName string, optFns ...func(o *MariaDBOptions)) (*MariaDB, error) {
opts := MariaDBOptions{
DriverName: "mysql",
}

for _, fn := range optFns {
fn(&opts)
}

mysql, err := NewMySQL(dataSourceName, func(o *MySQLOptions) {
o.DriverName = opts.DriverName
})
if err != nil {
return nil, err
}

return &MariaDB{
MySQL: mysql,
}, nil
}

// Dialect returns the dialect of the MariaDB database engine.
func (e *MariaDB) Dialect() string {
return "MariaDB"
}
19 changes: 18 additions & 1 deletion integration/sqldb/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,28 @@ import (
// Compile time check to ensure MySQL satisfies the Engine interface.
var _ Engine = (*MySQL)(nil)

// MySQLOptions holds options for the MySQL database engine.
type MySQLOptions struct {
DriverName string
}

// MySQL represents the MySQL database engine.
type MySQL struct {
db *sql.DB
*atlas
opts MySQLOptions
}

func NewMySQL(dataSourceName string) (*MySQL, error) {
// NewMySQL creates a new instance of the MySQL database engine.
func NewMySQL(dataSourceName string, optFns ...func(o *MySQLOptions)) (*MySQL, error) {
opts := MySQLOptions{
DriverName: "mysql",
}

for _, fn := range optFns {
fn(&opts)
}

db, err := sql.Open(opts.DriverName, dataSourceName)
if err != nil {
return nil, err
Expand All @@ -43,22 +50,32 @@ func NewMySQL(dataSourceName string) (*MySQL, error) {
}, nil
}

// Dialect returns the dialect of the MySQL database engine.
func (e *MySQL) Dialect() string {
return "MySQL"
}

// Exec executes an SQL query with the provided query string and arguments (args), returning the result and any errors encountered.
func (e *MySQL) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return e.db.ExecContext(ctx, query, args...)
}

// Query executes an SQL query with the provided query string and arguments (args), returning the rows and any errors encountered.
func (e *MySQL) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return e.db.QueryContext(ctx, query, args...)
}

// QueryRow executes an SQL query with the provided query string and arguments (args), returning a single row and any errors encountered.
func (e *MySQL) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return e.db.QueryRowContext(ctx, query, args...)
}

// SampleRowsQuery returns the query to retrieve a sample of rows from the specified table (table) with a limit of (k) rows.
func (e *MySQL) SampleRowsQuery(table string, k uint) string {
return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, k)
}

// Close closes the database connection.
func (e *MySQL) Close() error {
return e.db.Close()
}
19 changes: 18 additions & 1 deletion integration/sqldb/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,28 @@ import (
// Compile time check to ensure Postgres satisfies the Engine interface.
var _ Engine = (*Postgres)(nil)

// PostgresOptions holds options for the Postgres database engine.
type PostgresOptions struct {
DriverName string
}

// Postgres represents the Postgres database engine.
type Postgres struct {
db *sql.DB
*atlas
opts PostgresOptions
}

func NewPostgres(dataSourceName string) (*Postgres, error) {
// NewPostgres creates a new instance of the Postgres database engine.
func NewPostgres(dataSourceName string, optFns ...func(o *PostgresOptions)) (*Postgres, error) {
opts := PostgresOptions{
DriverName: "pgx",
}

for _, fn := range optFns {
fn(&opts)
}

db, err := sql.Open(opts.DriverName, dataSourceName)
if err != nil {
return nil, err
Expand All @@ -43,22 +50,32 @@ func NewPostgres(dataSourceName string) (*Postgres, error) {
}, nil
}

// Dialect returns the dialect of the Postgres database engine.
func (e *Postgres) Dialect() string {
return "Postgres"
}

// Exec executes an SQL query with the provided query string and arguments (args), returning the result and any errors encountered.
func (e *Postgres) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return e.db.ExecContext(ctx, query, args...)
}

// Query executes an SQL query with the provided query string and arguments (args), returning the rows and any errors encountered.
func (e *Postgres) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return e.db.QueryContext(ctx, query, args...)
}

// QueryRow executes an SQL query with the provided query string and arguments (args), returning a single row and any errors encountered.
func (e *Postgres) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return e.db.QueryRowContext(ctx, query, args...)
}

// SampleRowsQuery returns the query to retrieve a sample of rows from the specified table (table) with a limit of (k) rows.
func (e *Postgres) SampleRowsQuery(table string, k uint) string {
return fmt.Sprintf("SELECT * FROM %s LIMIT %d", table, k)
}

// Close closes the database connection.
func (e *Postgres) Close() error {
return e.db.Close()
}
Loading

0 comments on commit fa001ff

Please sign in to comment.