Skip to content

Instantly share code, notes, and snippets.

@vbogretsov
Created August 17, 2021 19:04
Show Gist options
  • Save vbogretsov/0a3987ef6ce0540e3207052652535a96 to your computer and use it in GitHub Desktop.
Save vbogretsov/0a3987ef6ce0540e3207052652535a96 to your computer and use it in GitHub Desktop.
Gorm transaction wrapper
// GormTx lets several repositories share one GORM handle while callers bracket
// units of work with Begin / Commit / Close. Begins may nest; each level is a
// separate *gorm.DB obtained by calling Begin on the level below it.
//
// Typical usage:
//
//	if err := tx.Begin(); err != nil {
//		return err
//	}
//	defer tx.Close() // rolls back unless Commit succeeded
//	// ... work through repositories bound to tx ...
//	return tx.Commit()
//
// Invariants maintained by every method:
//   - stack[0] is the root handle given to NewTransaction; it is never closed.
//   - stack[1:] holds every level that has been begun and not yet closed,
//     innermost last. A nil entry marks a level whose Commit succeeded and that
//     now only awaits its Close.
//   - clos == len(stack)-1, the index of the level the next Close will end.
//   - open is the index of the innermost non-nil (uncommitted) entry, or 0;
//     db() returns stack[open].
//
// GormTx performs no locking; do not share one value between goroutines.
type GormTx struct {
	stack []*gorm.DB
	open  int
	clos  int
}

// NewTransaction returns a GormTx rooted at db with no transaction open; until
// Begin is called, db() hands out db itself.
func NewTransaction(db *gorm.DB) *GormTx {
	return &GormTx{
		stack: []*gorm.DB{db},
		open:  0,
		clos:  0,
	}
}

// db returns the handle work should currently go through: the innermost
// uncommitted transaction, or the root handle when none is open.
func (tx *GormTx) db() *gorm.DB {
	return tx.stack[tx.open]
}

// Begin starts a new transaction from the current handle (db()) and makes it
// the active level. Each successful Begin must be matched by exactly one Close.
// A failed Begin pushes nothing, so it must not be followed by Close (that
// would end the enclosing level instead).
func (tx *GormTx) Begin() error {
	next := tx.db().Begin()
	if next.Error != nil {
		return next.Error
	}
	tx.stack = append(tx.stack, next)
	// Derive both indices from the stack length so they always refer to the
	// entry just pushed, even after earlier levels were committed or closed.
	tx.clos = len(tx.stack) - 1
	tx.open = tx.clos
	return nil
}

// Commit commits the innermost uncommitted level.
//
// On success that level is marked committed (its pending Close will not roll
// it back) and db() falls back to the nearest enclosing uncommitted level, or
// the root. On failure the level is left in place so the deferred Close still
// attempts a rollback and pops it.
//
// It returns an error if no transaction is open.
func (tx *GormTx) Commit() error {
	if tx.open == 0 {
		return errors.New("commit failed because transaction wasn't started")
	}
	if err := tx.db().Commit().Error; err != nil {
		return err
	}
	tx.stack[tx.open] = nil
	tx.open = tx.innermostActive()
	return nil
}

// Close ends the innermost level that has not been closed yet. It pops the level
// and rolls the transaction back unless Commit already succeeded for it. Close
// is a no-op when no level is outstanding. It returns the rollback error, if any.
func (tx *GormTx) Close() error {
	if tx.clos == 0 {
		return nil
	}
	top := tx.stack[tx.clos]
	// Pop before rolling back so the bookkeeping stays consistent even when
	// Rollback reports an error.
	tx.stack[tx.clos] = nil // drop the backing array's reference to the popped handle
	tx.stack = tx.stack[:tx.clos]
	tx.clos--
	tx.open = tx.innermostActive()
	if top == nil {
		return nil // committed earlier; nothing to roll back
	}
	return top.Rollback().Error
}

// innermostActive returns the index of the topmost stack entry that is still
// uncommitted (non-nil), or 0 (the root) when every begun level is committed.
func (tx *GormTx) innermostActive() int {
	for i := len(tx.stack) - 1; i > 0; i-- {
		if tx.stack[i] != nil {
			return i
		}
	}
	return 0
}
t.Run("Atomic", func(t *testing.T) {
	// Fixtures shared by both subtests. Rollback runs first and must leave none
	// of these rows behind, which is what lets Commit insert the same IDs again.
	sess := model.Session{
		ID:      "atomic.session.123",
		Value:   "atomic.session.value.123",
		Created: 1600000000,
		Expires: 1600000010,
	}
	// NOTE(review): "[email protected]" is the placeholder that email-obfuscating
	// web pages substitute for addresses; confirm the intended Name value.
	user := model.User{
		ID:      "atomic.user.123",
		Name:    "[email protected]",
		Created: 1600000000,
	}
	refresh := model.RefreshToken{
		ID:      "atomic.refresh.123",
		UserID:  user.ID,
		Created: 1600000000,
		Expires: 1600000010,
	}

	t.Run("Rollback", func(t *testing.T) {
		tx := repo.NewTransaction(db)
		sessRepo := repo.NewSessions(tx)
		userRepo := repo.NewUsers(tx)
		refreshRepo := repo.NewRefreshTokens(tx)

		// Begin and Close without Commit: Close must undo all three inserts.
		test := func(tx repo.Transaction) {
			require.NoError(t, tx.Begin())
			defer func() { require.NoError(t, tx.Close()) }()
			require.NoError(t, sessRepo.Create(sess))
			require.NoError(t, userRepo.Create(user))
			require.NoError(t, refreshRepo.Create(refresh))
		}
		test(tx)

		// None of the rows may be visible after the rollback.
		_, err = sessRepo.Find(sess.ID)
		require.ErrorIs(t, err, repo.ErrorNotFound)
		_, err = userRepo.Find(user.Name)
		require.ErrorIs(t, err, repo.ErrorNotFound)
		_, err = refreshRepo.Find(refresh.ID)
		require.ErrorIs(t, err, repo.ErrorNotFound)
	})

	t.Run("Commit", func(t *testing.T) {
		tx := repo.NewTransaction(db)
		sessRepo := repo.NewSessions(tx)
		userRepo := repo.NewUsers(tx)
		refreshRepo := repo.NewRefreshTokens(tx)

		// Commit before the deferred Close: Close must not undo committed work.
		test := func(tx repo.Transaction) {
			require.NoError(t, tx.Begin())
			defer func() { require.NoError(t, tx.Close()) }()
			require.NoError(t, sessRepo.Create(sess))
			require.NoError(t, userRepo.Create(user))
			require.NoError(t, refreshRepo.Create(refresh))
			require.NoError(t, tx.Commit())
		}
		test(tx)

		// All rows must be visible after the commit. This subtest does not
		// delete them afterwards.
		_, err = sessRepo.Find(sess.ID)
		require.NoError(t, err)
		_, err = userRepo.Find(user.Name)
		require.NoError(t, err)
		_, err = refreshRepo.Find(refresh.ID)
		require.NoError(t, err)
	})
})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment