// Copyright 2019 getensh.com. All rights reserved. // Use of this source code is governed by getensh.com. package transtasker import ( "fmt" "log" "gorm.io/gorm" ) // Tasker 任务接口 type Tasker interface { Exec(db *gorm.DB) error Rollback(db *gorm.DB) error } // Exec 执行任务 func Exec(gormdb *gorm.DB, tasks ...Tasker) error { // 如果gormdb句柄为nil,直接返回错误 if gormdb == nil { log.Println("error: gormdb is nil") return fmt.Errorf("gormdb is nil") } var err error db := gormdb.Begin() commited := false oktasks := make([]Tasker, 0) defer func() { // 捕获异常,并执行回滚操作 if v := recover(); v != nil || !commited { for _, task := range oktasks { task.Rollback(db) } // db 统一回滚 db.Rollback() // 向上层抛出异常 if v != nil { panic(v) } } }() defer func() { if err != nil { for _, task := range oktasks { task.Rollback(db) } // db 统一回滚 db.Rollback() } }() // 顺序执行任务 for _, task := range tasks { if task != nil { if err = task.Exec(db); err != nil { return err } } oktasks = append(oktasks, task) } // 事务提交 db.Commit() commited = true return nil }