diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 00000000..36df95c0 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,22 @@ +Copyright (c) 2015 Mark Bates + +MIT License + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/README.md b/README.md index d497aa3e..5b932b09 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,17 @@ So what does Pop do exactly? Well, it wraps the absolutely amazing [https://gith Pop makes it easy to do CRUD operations, run migrations, and build/execute queries. Is Pop an ORM? I'll leave that up to you, the reader, to decide. +Pop, by default, follows conventions that were defined by the ActiveRecord Ruby gem, http://www.rubyonrails.org. What does this mean? + +* Tables must have an "id" column and a corresponding "ID" field on the `struct` being used. +* If there is a timestamp column named "created_at", "CreatedAt" on the `struct`, it will be set with the current time when the record is created. +* If there is a timestamp column named "updated_at", "UpdatedAt" on the `struct`, it will be set with the current time when the record is updated. +* Default databases are lowercase, underscored versions of the `struct` name. Examples: User{} is "users", FooBar{} is "foo_bars", etc... + +## Docs + +The API docs can be found at [http://godoc.org/github.com/markbates/pop](http://godoc.org/github.com/markbates/pop) + ## Connecting to Databases Pop is easily configured using a YAML file. The configuration file should be stored in `config/database.yml` or `database.yml`. diff --git a/belongs_to.go b/belongs_to.go index 5704f5dd..de247ef4 100644 --- a/belongs_to.go +++ b/belongs_to.go @@ -12,7 +12,7 @@ func (c *Connection) BelongsTo(model interface{}) *Query { // "model" passed into it. func (q *Query) BelongsTo(model interface{}) *Query { m := &Model{Value: model} - q.Where(fmt.Sprintf("%s = ?", m.AssociationName()), m.ID()) + q.Where(fmt.Sprintf("%s = ?", m.associationName()), m.ID()) return q } diff --git a/connection.go b/connection.go index daa56a34..56385f95 100644 --- a/connection.go +++ b/connection.go @@ -8,8 +8,11 @@ import ( "github.com/markbates/going/defaults" ) +// Connections contains all of the available connections var Connections = map[string]*Connection{} +// Connection represents all of the necessary details for +// talking with a datastore type Connection struct { Store Store Dialect Dialect @@ -20,6 +23,8 @@ func (c *Connection) String() string { return c.Dialect.URL() } +// NewConnection creates a new connection, and sets it's `Dialect` +// appropriately based on the `ConnectionDetails` passed into it. func NewConnection(deets *ConnectionDetails) *Connection { c := &Connection{ Timings: []time.Duration{}, @@ -35,12 +40,20 @@ func NewConnection(deets *ConnectionDetails) *Connection { return c } +// Connect takes the name of a connection, default is "development", and will +// return that connection from the available `Connections`. If a connection with +// that name can not be found an error will be returned. If a connection is +// found, and it has yet to open a connection with its underlying datastore, +// a connection to that store will be opened. func Connect(e string) (*Connection, error) { e = defaults.String(e, "development") c := Connections[e] if c == nil { return c, fmt.Errorf("Could not find connection named %s!", e) } + if c.Store != nil { + return c, nil + } db, err := sqlx.Open(c.Dialect.Details().Dialect, c.Dialect.URL()) if err == nil { c.Store = &dB{db} @@ -48,6 +61,9 @@ func Connect(e string) (*Connection, error) { return c, nil } +// Transaction will start a new transaction on the connection. If the inner function +// returns an error then the transaction will be rolled back, otherwise the transaction +// will automatically commit at the end. func (c *Connection) Transaction(fn func(tx *Connection) error) error { tx, err := c.Store.Transaction() if err != nil { @@ -67,6 +83,8 @@ func (c *Connection) Transaction(fn func(tx *Connection) error) error { return err } +// Rollback will open a new transaction and automatically rollback that transaction +// when the inner function returns, regardless. This can be useful for tests, etc... func (c *Connection) Rollback(fn func(tx *Connection)) error { tx, err := c.Store.Transaction() if err != nil { @@ -80,6 +98,8 @@ func (c *Connection) Rollback(fn func(tx *Connection)) error { fn(cn) return tx.Rollback() } + +// Q creates a new "empty" query for the current connection. func (c *Connection) Q() *Query { return Q(c) } diff --git a/dialect.go b/dialect.go index 0a98ee25..8011e32a 100644 --- a/dialect.go +++ b/dialect.go @@ -32,7 +32,7 @@ func genericCreate(store Store, model *Model, cols Columns) error { } id, err = res.LastInsertId() if err == nil { - model.SetID(int(id)) + model.setID(int(id)) } return err } diff --git a/doc.go b/doc.go new file mode 100644 index 00000000..35ef81ad --- /dev/null +++ b/doc.go @@ -0,0 +1,13 @@ +/* +So what does Pop do exactly? Well, it wraps the absolutely amazing https://github.com/jmoiron/sqlx library. It cleans up some of the common patterns and workflows usually associated with dealing with databases in Go. + +Pop makes it easy to do CRUD operations, run migrations, and build/execute queries. Is Pop an ORM? I'll leave that up to you, the reader, to decide. + +Pop, by default, follows conventions that were defined by the ActiveRecord Ruby gem, http://www.rubyonrails.org. What does this mean? + +* Tables must have an "id" column and a corresponding "ID" field on the `struct` being used. +* If there is a timestamp column named "created_at", "CreatedAt" on the `struct`, it will be set with the current time when the record is created. +* If there is a timestamp column named "updated_at", "UpdatedAt" on the `struct`, it will be set with the current time when the record is updated. +* Default databases are lowercase, underscored versions of the `struct` name. Examples: User{} is "users", FooBar{} is "foo_bars", etc... +*/ +package pop diff --git a/executors.go b/executors.go index 3538c7e0..3d2d91d3 100644 --- a/executors.go +++ b/executors.go @@ -31,8 +31,8 @@ func (c *Connection) Create(model interface{}, excludeColumns ...string) error { cols := ColumnsForStruct(model, sm.TableName()) cols.Remove(excludeColumns...) - sm.TouchCreatedAt() - sm.TouchUpdatedAt() + sm.touchCreatedAt() + sm.touchUpdatedAt() return c.Dialect.Create(c.Store, sm, cols) }) @@ -46,7 +46,7 @@ func (c *Connection) Update(model interface{}, excludeColumns ...string) error { cols.Remove("id", "created_at") cols.Remove(excludeColumns...) - sm.TouchUpdatedAt() + sm.touchUpdatedAt() return c.Dialect.Update(c.Store, sm, cols) }) diff --git a/finders.go b/finders.go index 15e64151..bd05fb6f 100644 --- a/finders.go +++ b/finders.go @@ -2,18 +2,30 @@ package pop import "reflect" +// Find the first record of the model in the database with a particular id. +// +// c.Find(&User{}, 1) func (c *Connection) Find(model interface{}, id int) error { return Q(c).Find(model, id) } +// Find the first record of the model in the database with a particular id. +// +// q.Find(&User{}, 1) func (q *Query) Find(model interface{}, id int) error { return q.Where("id = ?", id).First(model) } +// First record of the model in the database that matches the query. +// +// c.First(&User{}) func (c *Connection) First(model interface{}) error { return Q(c).First(model) } +// First record of the model in the database that matches the query. +// +// q.Where("name = ?", "mark").First(&User{}) func (q *Query) First(model interface{}) error { return q.Connection.timeFunc("First", func() error { q.Limit(1) @@ -22,10 +34,16 @@ func (q *Query) First(model interface{}) error { }) } +// Last record of the model in the database that matches the query. +// +// c.Last(&User{}) func (c *Connection) Last(model interface{}) error { return Q(c).Last(model) } +// Last record of the model in the database that matches the query. +// +// q.Where("name = ?", "mark").Last(&User{}) func (q *Query) Last(model interface{}) error { return q.Connection.timeFunc("Last", func() error { q.Limit(1) @@ -35,10 +53,16 @@ func (q *Query) Last(model interface{}) error { }) } +// All retrieves all of the records in the database that match the query. +// +// c.All(&[]User{}) func (c *Connection) All(models interface{}) error { return Q(c).All(models) } +// All retrieves all of the records in the database that match the query. +// +// q.Where("name = ?", "mark").All(&[]User{}) func (q *Query) All(models interface{}) error { return q.Connection.timeFunc("All", func() error { m := &Model{Value: models} @@ -59,19 +83,25 @@ func (q *Query) All(models interface{}) error { }) } -func (c *Connection) Exists(model interface{}) (bool, error) { - return Q(c).Exists(model) -} - +// Exists returns true/false if a record exists in the database that matches +// the query. +// +// q.Where("name = ?", "mark").Exists(&User{}) func (q *Query) Exists(model interface{}) (bool, error) { i, err := q.Count(model) return i != 0, err } +// Count the number of records in the database. +// +// c.Count(&User{}) func (c *Connection) Count(model interface{}) (int, error) { return Q(c).Count(model) } +// Count the number of records in the database. +// +// q.Where("name = ?", "mark").Count(&User{}) func (q Query) Count(model interface{}) (int, error) { res := &rowCount{} err := q.Connection.timeFunc("Count", func() error { diff --git a/model.go b/model.go index 8d7ee71f..8c9e2169 100644 --- a/model.go +++ b/model.go @@ -5,78 +5,60 @@ import ( "reflect" "time" - "github.com/markbates/going/validate" "github.com/markbates/inflect" ) var tableMap = map[string]string{} +// MapTableName allows for the customize table mapping +// between a name and the database. For example the value +// `User{}` will automatically map to "users". +// MapTableName would allow this to change. +// +// m := &pop.Model{Value: User{}} +// m.TableName() // "users" +// +// pop.MapTableName("user", "people") +// m = &pop.Model{Value: User{}} +// m.TableName() // "people" func MapTableName(name string, tableName string) { tableMap[name] = tableName } type Value interface{} +// Model is used throughout Pop to wrap the end user interface +// that is passed in to many functions. type Model struct { Value tableName string } -func (m *Model) FieldByName(s string) (reflect.Value, error) { - el := reflect.ValueOf(m.Value).Elem() - fbn := el.FieldByName(s) - if !fbn.IsValid() { - return fbn, fmt.Errorf("Model does not have a field named %s", s) - } - return fbn, nil -} - -func (m *Model) AssociationName() string { - tn := inflect.Singularize(m.TableName()) - return fmt.Sprintf("%s_id", tn) -} - -func (m *Model) Validate(*Connection) (*validate.Errors, error) { - return validate.NewErrors(), nil -} - -func (m *Model) ValidateNew(*Connection) (*validate.Errors, error) { - return validate.NewErrors(), nil -} - -func (m *Model) ValidateUpdate(*Connection) (*validate.Errors, error) { - return validate.NewErrors(), nil -} - +// func (m *Model) Validate(*Connection) (*validate.Errors, error) { +// return validate.NewErrors(), nil +// } +// +// func (m *Model) ValidateNew(*Connection) (*validate.Errors, error) { +// return validate.NewErrors(), nil +// } +// +// func (m *Model) ValidateUpdate(*Connection) (*validate.Errors, error) { +// return validate.NewErrors(), nil +// } + +// ID returns the ID of the Model. All models must have an `ID` field this is +// of type `int`. func (m *Model) ID() int { - fbn, err := m.FieldByName("ID") + fbn, err := m.fieldByName("ID") if err != nil { return 0 } return int(fbn.Int()) } -func (m *Model) SetID(i int) { - fbn, err := m.FieldByName("ID") - if err == nil { - fbn.SetInt(int64(i)) - } -} - -func (m *Model) TouchCreatedAt() { - fbn, err := m.FieldByName("CreatedAt") - if err == nil { - fbn.Set(reflect.ValueOf(time.Now())) - } -} - -func (m *Model) TouchUpdatedAt() { - fbn, err := m.FieldByName("UpdatedAt") - if err == nil { - fbn.Set(reflect.ValueOf(time.Now())) - } -} - +// TableName returns the corresponding name of the underlying database table +// for a given `Model`. See also `MapTableName` to change the default name of +// the table. func (m *Model) TableName() string { if m.tableName != "" { return m.tableName @@ -100,3 +82,38 @@ func (m *Model) TableName() string { } return tableMap[name] } + +func (m *Model) fieldByName(s string) (reflect.Value, error) { + el := reflect.ValueOf(m.Value).Elem() + fbn := el.FieldByName(s) + if !fbn.IsValid() { + return fbn, fmt.Errorf("Model does not have a field named %s", s) + } + return fbn, nil +} + +func (m *Model) associationName() string { + tn := inflect.Singularize(m.TableName()) + return fmt.Sprintf("%s_id", tn) +} + +func (m *Model) setID(i int) { + fbn, err := m.fieldByName("ID") + if err == nil { + fbn.SetInt(int64(i)) + } +} + +func (m *Model) touchCreatedAt() { + fbn, err := m.fieldByName("CreatedAt") + if err == nil { + fbn.Set(reflect.ValueOf(time.Now())) + } +} + +func (m *Model) touchUpdatedAt() { + fbn, err := m.fieldByName("UpdatedAt") + if err == nil { + fbn.Set(reflect.ValueOf(time.Now())) + } +} diff --git a/paginator.go b/paginator.go index 4b9846b1..cff0c177 100644 --- a/paginator.go +++ b/paginator.go @@ -3,8 +3,16 @@ package pop import ( "encoding/json" "strconv" + + "github.com/markbates/going/defaults" ) +var PaginatorPerPageDefault = 20 +var PaginatorPageKey = "page" +var PaginatorPerPageKey = "per_page" + +// Paginator is a type used to represent the pagination of records +// from the database. type Paginator struct { // Current page you're on Page int `json:"page"` @@ -25,6 +33,8 @@ func (p Paginator) String() string { return string(b) } +// NewPaginator returns a new `Paginator` value with the appropriate +// defaults set. func NewPaginator(page int, per_page int) *Paginator { p := &Paginator{Page: page, PerPage: per_page} p.Offset = (p.Page - 1) * p.PerPage @@ -35,16 +45,15 @@ type PaginationParams interface { Get(key string) string } +// NewPaginatorFromParams takes an interface of type `PaginationParams`, +// the `url.Values` type works great with this interface, and returns +// a new `Paginator` based on the params or `PaginatorPageKey` and +// `PaginatorPerPageKey`. Defaults are `1` for the page and +// PaginatorPerPageDefault for the per page value. func NewPaginatorFromParams(params PaginationParams) *Paginator { - page := params.Get("page") - if page == "" { - page = "1" - } + page := defaults.String(params.Get("page"), "1") - per_page := params.Get("per_page") - if per_page == "" { - per_page = "20" - } + per_page := defaults.String(params.Get("per_page"), strconv.Itoa(PaginatorPerPageDefault)) p, err := strconv.Atoi(page) if err != nil { @@ -53,24 +62,44 @@ func NewPaginatorFromParams(params PaginationParams) *Paginator { pp, err := strconv.Atoi(per_page) if err != nil { - pp = 20 + pp = PaginatorPerPageDefault } return NewPaginator(p, pp) } +// Paginate records returned from the database. +// +// q := c.Paginate(2, 15) +// q.All(&[]User{}) +// q.Paginator func (c *Connection) Paginate(page int, per_page int) *Query { return Q(c).Paginate(page, per_page) } +// Paginate records returned from the database. +// +// q = q.Paginate(2, 15) +// q.All(&[]User{}) +// q.Paginator func (q *Query) Paginate(page int, per_page int) *Query { q.Paginator = NewPaginator(page, per_page) return q } +// Paginate records returned from the database. +// +// q := c.PaginateFromParams(req.URL.Query()) +// q.All(&[]User{}) +// q.Paginator func (c *Connection) PaginateFromParams(params PaginationParams) *Query { return Q(c).PaginateFromParams(params) } +// Paginate records returned from the database. +// +// q = q.PaginateFromParams(req.URL.Query()) +// q.All(&[]User{}) +// q.Paginator func (q *Query) PaginateFromParams(params PaginationParams) *Query { q.Paginator = NewPaginatorFromParams(params) return q diff --git a/postgresql.go b/postgresql.go index 2a93c0f6..6ed588ac 100644 --- a/postgresql.go +++ b/postgresql.go @@ -33,7 +33,7 @@ func (p *PostgreSQL) Create(store Store, model *Model, cols Columns) error { } err = stmt.Get(&id, model.Value) if err == nil { - model.SetID(id.ID) + model.setID(id.ID) } return err } diff --git a/query.go b/query.go index f3259fb2..1aeed8d6 100644 --- a/query.go +++ b/query.go @@ -1,5 +1,7 @@ package pop +// Query is the main value that is used to build up a query +// to be executed against the `Connection`. type Query struct { RawSQL *clause limitResults int @@ -11,42 +13,69 @@ type Query struct { Connection *Connection } +// RawQuery will override the query building feature of Pop and will use +// whatever query you want to execute against the `Connection`. You can continue +// to use the `?` argument syntax. +// +// c.RawQuery("select * from foo where id = ?", 1) func (c *Connection) RawQuery(stmt string, args ...interface{}) *Query { return Q(c).RawQuery(stmt, args...) } +// RawQuery will override the query building feature of Pop and will use +// whatever query you want to execute against the `Connection`. You can continue +// to use the `?` argument syntax. +// +// q.RawQuery("select * from foo where id = ?", 1) func (q *Query) RawQuery(stmt string, args ...interface{}) *Query { q.RawSQL = &clause{stmt, args} return q } +// Where will append a where clause to the query. You may use `?` in place of +// arguments. +// +// c.Where("id = ?", 1) func (c *Connection) Where(stmt string, args ...interface{}) *Query { return Q(c).Where(stmt, args...) } +// Where will append a where clause to the query. You may use `?` in place of +// arguments. +// +// q.Where("id = ?", 1) func (q *Query) Where(stmt string, args ...interface{}) *Query { q.whereClauses = append(q.whereClauses, clause{stmt, args}) return q } +// Order will append an order clause to the query. +// +// c.Order("name desc") func (c *Connection) Order(stmt string) *Query { return Q(c).Order(stmt) } +// Order will append an order clause to the query. +// +// q.Order("name desc") func (q *Query) Order(stmt string) *Query { q.orderClauses = append(q.orderClauses, clause{stmt, []interface{}{}}) return q } +// Limit will add a limit clause to the query. func (c *Connection) Limit(limit int) *Query { return Q(c).Limit(limit) } +// Limit will add a limit clause to the query. func (q *Query) Limit(limit int) *Query { q.limitResults = limit return q } +// Q will create a new "empty" query from the current connection. func Q(c *Connection) *Query { return &Query{ RawSQL: &clause{}, @@ -54,11 +83,15 @@ func Q(c *Connection) *Query { } } +// ToSQL will generate SQL and the appropriate arguments for that SQL +// from the `Model` passed in. func (q Query) ToSQL(model *Model, addColumns ...string) (string, []interface{}) { - sb := NewSQLBuilder(q, model, addColumns...) + sb := q.ToSQLBuilder(model, addColumns...) return sb.String(), sb.Args() } +// ToSQLBuilder returns a new `SQLBuilder` that can be used to generate SQL, +// get arguments, and more. func (q Query) ToSQLBuilder(model *Model, addColumns ...string) *SQLBuilder { return NewSQLBuilder(q, model, addColumns...) } diff --git a/scopes.go b/scopes.go index 8def8656..dceb996d 100644 --- a/scopes.go +++ b/scopes.go @@ -2,10 +2,28 @@ package pop type ScopeFunc func(q *Query) *Query +// Scope the query by using a `ScopeFunc` +// +// func ByName(name string) ScopeFunc { +// return func(q *Query) *Query { +// return q.Where("name = ?", name) +// } +// } +// +// q.Scope(ByName("mark").Where("id = ?", 1).First(&User{}) func (q *Query) Scope(sf ScopeFunc) *Query { return sf(q) } +// Scope the query by using a `ScopeFunc` +// +// func ByName(name string) ScopeFunc { +// return func(q *Query) *Query { +// return q.Where("name = ?", name) +// } +// } +// +// c.Scope(ByName("mark")).First(&User{}) func (c *Connection) Scope(sf ScopeFunc) *Query { return Q(c).Scope(sf) } diff --git a/scopes_test.go b/scopes_test.go index 6b2505ee..9b57e9b9 100644 --- a/scopes_test.go +++ b/scopes_test.go @@ -1,7 +1,6 @@ package pop_test import ( - "os" "testing" "github.com/markbates/pop" @@ -10,24 +9,18 @@ import ( func Test_Scopes(t *testing.T) { r := require.New(t) - oql := "SELECT name as full_name, users.alive, users.bio, users.birth_date, users.created_at, users.id, users.name, users.price, users.updated_at FROM users AS users" + oql := "SELECT enemies.A FROM enemies AS enemies" - transaction(func(tx *pop.Connection) { - u := &pop.Model{Value: &User{}} - q := tx.Q() + m := &pop.Model{Value: &Enemy{}} - s, _ := q.ToSQL(u) - r.Equal(oql, s) + q := PDB.Q() + s, _ := q.ToSQL(m) + r.Equal(oql, s) - q.Scope(func(qy *pop.Query) *pop.Query { - return qy.Where("id = ?", 1) - }) - - s, _ = q.ToSQL(u) - if os.Getenv("SODA_DIALECT") == "postgres" { - r.Equal(oql+" WHERE id = $1", s) - } else { - r.Equal(oql+" WHERE id = ?", s) - } + q.Scope(func(qy *pop.Query) *pop.Query { + return qy.Where("id = ?", 1) }) + + s, _ = q.ToSQL(m) + r.Equal(ts(oql+" WHERE id = ?"), s) } diff --git a/sql_builder.go b/sql_builder.go index 337f5b5f..39143eaa 100644 --- a/sql_builder.go +++ b/sql_builder.go @@ -113,8 +113,8 @@ func (sq *SQLBuilder) buildfromClauses() fromClauses { func (sq *SQLBuilder) buildWhereClauses(sql string) string { mcs := sq.Query.belongsToThroughClauses for _, mc := range mcs { - sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.AssociationName()), mc.BelongsTo.ID()) - sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.AssociationName())) + sq.Query.Where(fmt.Sprintf("%s.%s = ?", mc.Through.TableName(), mc.BelongsTo.associationName()), mc.BelongsTo.ID()) + sq.Query.Where(fmt.Sprintf("%s.id = %s.%s", sq.Model.TableName(), mc.Through.TableName(), sq.Model.associationName())) } wc := sq.Query.whereClauses