create.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. package manual_task
  2. import (
  3. "adm-management/consts"
  4. "adm-management/errors"
  5. "adm-management/model"
  6. "adm-management/parser"
  7. v1 "adm-management/pb/v1"
  8. "context"
  9. "encoding/json"
  10. "fmt"
  11. "strconv"
  12. "strings"
  13. "gorm.io/gorm"
  14. "git.getensh.com/common/gopkgsv2/database"
  15. "git.getensh.com/common/gopkgsv2/logger"
  16. "go.uber.org/zap"
  17. "google.golang.org/grpc/status"
  18. )
  19. type OdsMessage struct {
  20. MsgType string `json:"msg_type"`
  21. SourceCode string `json:"source_code"` // 来源编码
  22. OfflineTaskId int64 `json:"offline_task_id"` // 离线消息任务id
  23. TaskList []int `json:"task_list"` // 任务列表
  24. From int `json:"from"` // 消息来源类型 1 数据库 2 excel
  25. Content string `json:"content"`
  26. Timestamp int64 `json:"timestamp"`
  27. }
  28. func Create(ctx context.Context, req *v1.CreateRequest) (reply *v1.CreateReply, err error) {
  29. reply = &v1.CreateReply{}
  30. // 捕获各个task中的异常并返回给调用者
  31. defer func() {
  32. if r := recover(); r != nil {
  33. err = fmt.Errorf("%+v", r)
  34. e := &status.Status{}
  35. if er := json.Unmarshal([]byte(err.Error()), e); er != nil {
  36. logger.Error("err",
  37. zap.String("system_err", err.Error()),
  38. zap.Stack("stacktrace"))
  39. }
  40. }
  41. }()
  42. db := database.DB().Begin()
  43. err = createImpl(db, req)
  44. if err != nil {
  45. db.Rollback()
  46. } else {
  47. db.Commit()
  48. }
  49. return reply, err
  50. }
  51. func createImpl(db *gorm.DB, req *v1.CreateRequest) (err error) {
  52. req.Sql = strings.TrimSpace(req.Sql)
  53. req.Sql = strings.ToLower(req.Sql)
  54. if req.Type != consts.FromDb && req.Type != consts.FromExcel {
  55. return errors.ParamsError
  56. }
  57. if req.Type == consts.FromExcel && len(req.TaskIds) <= 0 {
  58. return errors.ParamsError
  59. }
  60. // 遍历任务名改为任务id
  61. var taskIds string
  62. var ids []string
  63. for _, i := range req.TaskIds {
  64. tasks, err := model.NewTaskList().Get(db.Where("task_name = ?", i))
  65. if err != nil && err != gorm.ErrRecordNotFound {
  66. return err
  67. }
  68. if err == gorm.ErrRecordNotFound {
  69. fmt.Println("not found")
  70. return nil
  71. }
  72. ids = append(ids, strconv.Itoa(int(tasks.TaskId)))
  73. }
  74. taskIds = strings.Join(ids, ",")
  75. // 构建消息
  76. odsMsg := OdsMessage{MsgType: consts.ODSOFFLINEIMPORT, From: int(req.Type)}
  77. if len(ids) > 0 {
  78. for _, v := range ids {
  79. taskId, _ := strconv.Atoi(v)
  80. odsMsg.TaskList = append(odsMsg.TaskList, taskId)
  81. }
  82. }
  83. offlineTask := &model.OfflineTask{TaskName: req.TaskName, Type: req.Type, Source: req.Source, Sql: req.Sql, TaskIds: taskIds}
  84. // TODO 事务处理
  85. err = offlineTask.Insert(db)
  86. if err != nil {
  87. if err == gorm.ErrRecordNotFound {
  88. return nil
  89. }
  90. return errors.SystemError
  91. }
  92. odsMsg.OfflineTaskId = offlineTask.ID
  93. // 选择源表
  94. if req.Type == 1 {
  95. // 通过source 获取库表
  96. //dataSource, err := model.NewDataList().Get(database.DB().Where("source_code = ?", req.Source))
  97. dataSource, err := model.NewDataList().Get(database.DB().Where("table_name = ?", req.Source))
  98. if err != nil && err != gorm.ErrRecordNotFound {
  99. return err
  100. }
  101. odsMsg.SourceCode = dataSource.SourceCode
  102. if req.Sql == "" {
  103. odsMsg.Content = fmt.Sprintf("select * from %s.%s", dataSource.Db, dataSource.TableName)
  104. } else if strings.HasPrefix(req.Sql, "where") {
  105. odsMsg.Content = fmt.Sprintf("select * from %s.%s %s", dataSource.Db, dataSource.TableName, req.Sql)
  106. } else {
  107. odsMsg.Content = fmt.Sprintf("select * from %s.%s where %s", dataSource.Db, dataSource.TableName, req.Sql)
  108. }
  109. // 判断sql是否正确
  110. err = db.Exec(fmt.Sprintf("%s limit 1", odsMsg.Content)).Error
  111. if err != nil {
  112. return errors.SqlError
  113. }
  114. } else {
  115. contentList := strings.Split(req.Source, "/")
  116. odsMsg.Content = contentList[len(contentList)-1]
  117. }
  118. // 发送消息
  119. odsMsgByte, err := json.Marshal(odsMsg)
  120. if err != nil {
  121. return err
  122. }
  123. err = parser.OdsMq.PublishMsg(odsMsgByte)
  124. if err != nil {
  125. return err
  126. }
  127. return nil
  128. }