Skip to content

Commit

Permalink
Merge pull request #33 from asdine/support-connbegintx
Browse files Browse the repository at this point in the history
Add support for driver.ConnBeginTx
  • Loading branch information
asdine authored Jan 21, 2021
2 parents fc410b0 + 4803583 commit 9db12c7
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
8 changes: 8 additions & 0 deletions sqlhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ func (drv *Driver) Open(name string) (driver.Conn, error) {
return conn, err
}

// Drivers that don't implement driver.ConnBeginTx are not supported.
if _, ok := conn.(driver.ConnBeginTx); !ok {
return nil, errors.New("driver must implement driver.ConnBeginTx")
}

wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) && isQueryer(conn) && isSessionResetter(conn) {
return &ExecerQueryerContextWithSessionResetter{wrapped,
Expand Down Expand Up @@ -97,6 +102,9 @@ func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
}

// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
Expand Down
25 changes: 25 additions & 0 deletions sqlhooks_interface_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
*FakeConnQueryer
*FakeConnSessionResetter
}{}, nil
case "NonConnBeginTx":
return &FakeConnUnsupported{}, nil
}

return nil, errors.New("Fake driver not implemented")
Expand All @@ -80,6 +82,9 @@ func (*FakeConnBasic) Close() error {
func (*FakeConnBasic) Begin() (driver.Tx, error) {
return nil, errors.New("Not implemented")
}
func (*FakeConnBasic) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
return nil, errors.New("Not implemented")
}

type FakeConnExecer struct{}

Expand Down Expand Up @@ -111,6 +116,20 @@ func (*FakeConnSessionResetter) ResetSession(ctx context.Context) error {
return errors.New("Not implemented")
}

// FakeConnUnsupported implements a database/sql.driver.Conn but doesn't implement
// driver.ConnBeginTx.
type FakeConnUnsupported struct{}

func (*FakeConnUnsupported) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("Not implemented")
}
func (*FakeConnUnsupported) Close() error {
return errors.New("Not implemented")
}
func (*FakeConnUnsupported) Begin() (driver.Tx, error) {
return nil, errors.New("Not implemented")
}

func TestInterfaces(t *testing.T) {
drv := Wrap(&fakeDriver{}, &testHooks{})

Expand All @@ -123,3 +142,9 @@ func TestInterfaces(t *testing.T) {
}
}
}

func TestUnsupportedDrivers(t *testing.T) {
drv := Wrap(&fakeDriver{}, &testHooks{})
_, err := drv.Open("NonConnBeginTx")
require.EqualError(t, err, "driver must implement driver.ConnBeginTx")
}

0 comments on commit 9db12c7

Please sign in to comment.