db.go 6.2 KB


  1. // Copyright 2019 autocareai.com. All rights reserved.
  2. // Use of this source code is governed by autocareai.com.
  3. package db
import (
	"context"
	"errors"
	"fmt"
	"log"
	"os"
	"reflect"
	"time"

	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"gorm.io/driver/mysql"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"
)
  18. var (
  19. db *gorm.DB
  20. )
  21. // Colors
  22. const (
  23. Reset = "\033[0m"
  24. Red = "\033[31m"
  25. Green = "\033[32m"
  26. Yellow = "\033[33m"
  27. Blue = "\033[34m"
  28. Magenta = "\033[35m"
  29. Cyan = "\033[36m"
  30. White = "\033[37m"
  31. BlueBold = "\033[34;1m"
  32. MagentaBold = "\033[35;1m"
  33. RedBold = "\033[31;1m"
  34. YellowBold = "\033[33;1m"
  35. )
  36. // Writer log writer interface
  37. type Writer interface {
  38. Printf(string, ...interface{})
  39. }
  40. type CustomLogger struct {
  41. Writer
  42. Config
  43. infoStr, warnStr, errStr string
  44. traceStr, traceErrStr, traceWarnStr string
  45. }
  46. type Config struct {
  47. SlowThreshold time.Duration
  48. Colorful bool
  49. IgnoreRecordNotFoundError bool
  50. LogLevel zapcore.Level
  51. TraceWithLevel zapcore.Level
  52. Zap *zap.Logger
  53. }
  54. type callback struct{}
  55. // 处理gorm v2.0 find 为空返回nil的问题
  56. func (c callback) ParserError(db *gorm.DB) {
  57. if db.Error == nil && db.RowsAffected == 0 && (db.Statement.ReflectValue.Kind() == reflect.Slice || db.Statement.ReflectValue.Kind() == reflect.Struct) {
  58. db.Error = gorm.ErrRecordNotFound
  59. }
  60. }
  61. func RegisterCallback(db *gorm.DB) {
  62. _ = db.Callback().Query().After("gorm:query").Register("error_parser", callback{}.ParserError)
  63. }
  64. var gormLogLevelMap = map[logger.LogLevel]zapcore.Level{
  65. logger.Info: zap.InfoLevel,
  66. logger.Warn: zap.WarnLevel,
  67. logger.Error: zap.ErrorLevel,
  68. }
  69. // LogMode log mode
  70. func (c CustomLogger) LogMode(level logger.LogLevel) logger.Interface {
  71. zapLevel, exist := gormLogLevelMap[level]
  72. if !exist {
  73. zapLevel = zap.DebugLevel
  74. }
  75. newlogger := c
  76. newlogger.LogLevel = zapLevel
  77. newlogger.TraceWithLevel = zapLevel
  78. return &newlogger
  79. }
  80. func (c CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) {
  81. if c.Config.LogLevel <= zap.InfoLevel && c.Config.Zap != nil {
  82. c.Config.Zap.Sugar().Infof(msg, data...)
  83. }
  84. }
  85. func (c CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
  86. if c.Config.LogLevel <= zap.WarnLevel && c.Config.Zap != nil {
  87. c.Config.Zap.Sugar().Warnf(msg, data...)
  88. }
  89. }
  90. func (c CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) {
  91. if c.Config.LogLevel <= zap.ErrorLevel && c.Config.Zap != nil {
  92. c.Config.Zap.Sugar().Errorf(msg, data...)
  93. }
  94. }
  95. func (c CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) {
  96. elapsed := time.Since(begin)
  97. sql, rows := fc()
  98. if c.Config.LogLevel <= zap.InfoLevel {
  99. c.Printf(c.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
  100. }
  101. if c.Config.Zap == nil {
  102. return
  103. }
  104. switch {
  105. case err != nil:
  106. c.Config.Zap.Error("sql: "+sql, zap.Float64("elapsed", elapsed.Seconds()), zap.Int64("rows", rows), zap.String("error", err.Error()))
  107. case c.Config.SlowThreshold != 0 && elapsed.Seconds() > c.Config.SlowThreshold.Seconds():
  108. c.Config.Zap.Warn("sql: "+sql, zap.Float64("elapsed", elapsed.Seconds()), zap.Int64("rows", rows), zap.Float64("threshold", c.Config.SlowThreshold.Seconds()))
  109. case c.Config.LogLevel == zap.DebugLevel:
  110. log := c.Config.Zap.Debug
  111. if c.Config.TraceWithLevel == zap.InfoLevel {
  112. log = c.Config.Zap.Info
  113. } else if c.Config.TraceWithLevel == zap.WarnLevel {
  114. log = c.Config.Zap.Warn
  115. } else if c.Config.TraceWithLevel == zap.ErrorLevel {
  116. log = c.Config.Zap.Error
  117. }
  118. log("sql: "+sql, zap.Float64("elapsed", elapsed.Seconds()), zap.Int64("rows", rows))
  119. }
  120. }
  121. // NewGormLogger 返回带 zap logger 的 GormLogger
  122. func NewGormLogger(writer Writer, config Config) CustomLogger {
  123. var (
  124. infoStr = "%s\n[info] "
  125. warnStr = "%s\n[warn] "
  126. errStr = "%s\n[error] "
  127. traceStr = "%s\n[%.3fms] [rows:%v] %s"
  128. traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s"
  129. traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s"
  130. )
  131. if config.Colorful {
  132. infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset
  133. warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset
  134. errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
  135. traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
  136. traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset
  137. traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s"
  138. }
  139. return CustomLogger{
  140. Writer: writer,
  141. Config: config,
  142. infoStr: infoStr,
  143. warnStr: warnStr,
  144. errStr: errStr,
  145. traceStr: traceStr,
  146. traceWarnStr: traceWarnStr,
  147. traceErrStr: traceErrStr,
  148. }
  149. }
  150. // Setup 建立连接
  151. func Setup(user, passwd, addr, dbname, charset string, maxIdle, maxConn int, logMode bool,logger *zap.Logger) *gorm.DB {
  152. conf := Config{
  153. SlowThreshold: time.Second,
  154. Colorful: true,
  155. IgnoreRecordNotFoundError: false,
  156. LogLevel: zap.ErrorLevel,
  157. TraceWithLevel: zap.ErrorLevel,
  158. Zap: logger,
  159. }
  160. // 是否开启调试模式
  161. if logMode {
  162. conf.LogLevel = zap.DebugLevel
  163. conf.TraceWithLevel = zap.DebugLevel
  164. conf.SlowThreshold = time.Millisecond * 200
  165. }
  166. config := &gorm.Config{
  167. Logger: NewGormLogger(log.New(os.Stdout, "\r\n", log.LstdFlags), conf),
  168. }
  169. // 组装参数
  170. dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=%s&parseTime=True&loc=Local",
  171. user, passwd, addr, dbname, charset)
  172. // 打开新连接
  173. var err error
  174. db, err = gorm.Open(mysql.Open(dsn), config)
  175. if err != nil {
  176. log.Fatal("open mysql connection failed. err: ", err)
  177. }
  178. // 其他设置
  179. sqlDB, err := db.DB()
  180. if err != nil {
  181. log.Fatal("open mysql connection failed. err: ", err)
  182. }
  183. fmt.Println("数据库MaxIdle和MaxOpen:",maxIdle,maxConn)
  184. sqlDB.SetMaxIdleConns(maxIdle)
  185. sqlDB.SetMaxOpenConns(maxConn)
  186. RegisterCallback(db)
  187. return db
  188. }
  189. // DB 获取连接
  190. func DB() *gorm.DB {
  191. return db
  192. }
  193. // Close 关闭连接
  194. func Close() {
  195. if db != nil {
  196. sqlDB, _ := db.DB()
  197. _ = sqlDB.Close()
  198. }
  199. }