Skip to content

Commit

Permalink
added the ability to truncate all tables
Browse files Browse the repository at this point in the history
  • Loading branch information
markbates committed Mar 19, 2017
1 parent 061d092 commit 0238061
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 2 deletions.
4 changes: 4 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ func (c *Connection) Q() *Query {
return Q(c)
}

func (c *Connection) TruncateAll() error {
return c.Dialect.TruncateAll(c)
}

func (c *Connection) timeFunc(name string, fn func() error) error {
now := time.Now()
err := fn()
Expand Down
1 change: 1 addition & 0 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type dialect interface {
LoadSchema(io.Reader) error
FizzTranslator() fizz.Translator
Lock(func() error) error
TruncateAll(*Connection) error
}

func genericCreate(s store, model *Model, cols Columns) error {
Expand Down
17 changes: 17 additions & 0 deletions mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,27 @@ func (m *mysql) LoadSchema(r io.Reader) error {
return nil
}

func (m *mysql) TruncateAll(tx *Connection) error {
stmts := []struct {
Stmt string `db:"stmt"`
}{}
err := tx.RawQuery(mysqlTruncate, m.Details().Database).All(&stmts)
if err != nil {
return err
}
qs := []string{}
for _, x := range stmts {
qs = append(qs, x.Stmt)
}
return tx.RawQuery(strings.Join(qs, " ")).Exec()
}

func newMySQL(deets *ConnectionDetails) dialect {
cd := &mysql{
ConnectionDetails: deets,
}

return cd
}

const mysqlTruncate = "SELECT concat('TRUNCATE TABLE `', TABLE_NAME, '`;') as stmt FROM INFORMATION_SCHEMA.TABLES where TABLE_SCHEMA = ?"
17 changes: 17 additions & 0 deletions postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ func (p *postgresql) LoadSchema(r io.Reader) error {
return nil
}

func (p *postgresql) TruncateAll(tx *Connection) error {
return tx.RawQuery(pgTruncate).Exec()
}

func newPostgreSQL(deets *ConnectionDetails) dialect {
cd := &postgresql{
ConnectionDetails: deets,
Expand All @@ -197,3 +201,16 @@ func newPostgreSQL(deets *ConnectionDetails) dialect {
}
return cd
}

const pgTruncate = `DO
$func$
BEGIN
EXECUTE
(SELECT 'TRUNCATE TABLE '
|| string_agg(quote_ident(schemaname) || '.' || quote_ident(tablename), ', ')
|| ' CASCADE'
FROM pg_tables
WHERE schemaname = 'public'
);
END
$func$;`
24 changes: 22 additions & 2 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ func (m *sqlite) FizzTranslator() fizz.Translator {
}

func (m *sqlite) DumpSchema(w io.Writer) error {
cmd := exec.Command("sqlite3", m.URL(), ".schema")
cmd := exec.Command("sqlite3", m.Details().Database, ".schema")
Log(strings.Join(cmd.Args, " "))
cmd.Stdout = w
cmd.Stderr = os.Stderr
Expand All @@ -129,7 +129,7 @@ func (m *sqlite) DumpSchema(w io.Writer) error {
}

func (m *sqlite) LoadSchema(r io.Reader) error {
cmd := exec.Command("sqlite3", m.URL())
cmd := exec.Command("sqlite3", m.ConnectionDetails.Database)
in, err := cmd.StdinPipe()
if err != nil {
return err
Expand All @@ -153,6 +153,26 @@ func (m *sqlite) LoadSchema(r io.Reader) error {
return nil
}

func (m *sqlite) TruncateAll(tx *Connection) error {
const tableNames = `SELECT name FROM sqlite_master WHERE type = "table"`
names := []struct {
Name string `db:"name"`
}{}

err := tx.RawQuery(tableNames).All(&names)
if err != nil {
return err
}
stmts := []string{}
for _, n := range names {
stmts = append(stmts, fmt.Sprintf("DELETE FROM %s", n.Name))
}
if len(stmts) == 0 {
return nil
}
return tx.RawQuery(strings.Join(stmts, "; ")).Exec()
}

func newSQLite(deets *ConnectionDetails) dialect {
deets.URL = fmt.Sprintf("sqlite3://%s", deets.Database)
cd := &sqlite{
Expand Down

0 comments on commit 0238061

Please sign in to comment.