tasker.go 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. // Copyright 2019 getensh.com. All rights reserved.
  2. // Use of this source code is governed by getensh.com.
  3. package transtasker
  4. import (
  5. "fmt"
  6. "log"
  7. "gorm.io/gorm"
  8. )
  9. // Tasker 任务接口
  10. type Tasker interface {
  11. Exec(db *gorm.DB) error
  12. Rollback(db *gorm.DB) error
  13. }
  14. // Exec 执行任务
  15. func Exec(gormdb *gorm.DB, tasks ...Tasker) error {
  16. // 如果gormdb句柄为nil,直接返回错误
  17. if gormdb == nil {
  18. log.Println("error: gormdb is nil")
  19. return fmt.Errorf("gormdb is nil")
  20. }
  21. var err error
  22. db := gormdb.Begin()
  23. commited := false
  24. oktasks := make([]Tasker, 0)
  25. defer func() {
  26. // 捕获异常,并执行回滚操作
  27. if v := recover(); v != nil || !commited {
  28. for _, task := range oktasks {
  29. task.Rollback(db)
  30. }
  31. // db 统一回滚
  32. db.Rollback()
  33. // 向上层抛出异常
  34. if v != nil {
  35. panic(v)
  36. }
  37. }
  38. }()
  39. defer func() {
  40. if err != nil {
  41. for _, task := range oktasks {
  42. task.Rollback(db)
  43. }
  44. // db 统一回滚
  45. db.Rollback()
  46. }
  47. }()
  48. // 顺序执行任务
  49. for _, task := range tasks {
  50. if task != nil {
  51. if err = task.Exec(db); err != nil {
  52. return err
  53. }
  54. }
  55. oktasks = append(oktasks, task)
  56. }
  57. // 事务提交
  58. db.Commit()
  59. commited = true
  60. return nil
  61. }