rbac_access.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package rbac
  2. import (
  3. "gd_admin/apis"
  4. "gd_admin/errors"
  5. "gd_admin/impl/dbmodel"
  6. "fmt"
  7. "strconv"
  8. "strings"
  9. "gd_admin/common.in/utils"
  10. "github.com/astaxie/beego/orm"
  11. "go.uber.org/zap"
  12. "golang.org/x/net/context"
  13. )
  14. // 获取用户具体权限列表
  15. func GetUserAccess(uid int64) (map[string][]string, []string, error) {
  16. node := make(map[string][]string, 0)
  17. resource := make([]string, 0)
  18. // 获取分组id
  19. p := dbmodel.TGdAdminRbacAccess{}
  20. // where
  21. filter := map[string]interface{}{
  22. "uid": uid,
  23. }
  24. err := p.Fetch(orm.NewOrm(), filter)
  25. if err != nil {
  26. if err == orm.ErrNoRows {
  27. return node, resource, errors.AccessNotAllow
  28. }
  29. l.Error("mysql",
  30. zap.String("sql", fmt.Sprintf("SELECT * FROM %s", p.TableName())),
  31. zap.String("fields", utils.MarshalJsonString(filter)),
  32. zap.String("error", err.Error()))
  33. return node, resource, errors.DataBaseError
  34. }
  35. // 获取节点id
  36. nodeIds, err := getUserRbacNode(p.GroupId)
  37. if err != nil {
  38. return node, resource, err
  39. }
  40. if nodeIds == "" {
  41. return node, resource, nil
  42. }
  43. //转换id类型
  44. ids := strings.Split(nodeIds, ",")
  45. nodeId := make([]int, len(ids))
  46. for k, v := range ids {
  47. id, _ := strconv.Atoi(v)
  48. nodeId[k] = id
  49. }
  50. // where
  51. where := map[string]interface{}{
  52. "id__in": nodeId,
  53. }
  54. n := dbmodel.TGdAdminRbacNode{}
  55. list, err := n.FetchAll(orm.NewOrm(), where, []string{"id", "pid", "resource", "object"})
  56. if err != nil {
  57. l.Error("mysql",
  58. zap.String("sql", fmt.Sprintf("SELECT * FROM %s", n.TableName())),
  59. zap.String("fields", utils.MarshalJsonString(where)),
  60. zap.String("error", err.Error()))
  61. return node, resource, errors.DataBaseError
  62. }
  63. for _, v := range list {
  64. if v.Pid == 0 {
  65. // 根节点
  66. if _, ok := node[v.Resource]; !ok {
  67. node[v.Resource] = make([]string, 0)
  68. resource = append(resource, v.Resource)
  69. }
  70. } else {
  71. // 子节点
  72. node[v.Resource] = append(node[v.Resource], v.Object)
  73. }
  74. }
  75. return node, resource, nil
  76. }
  77. // 用户刷新页面时获取权限列表
  78. func GetAccess(ctx context.Context, req *apis.GetAccessReq, reply *apis.GetAccessReply) (err error) {
  79. // 验证参数
  80. if req.Uid <= 0 {
  81. return errors.ArgsError
  82. }
  83. // 获取权限列表
  84. reply.Access, reply.Resource, err = GetUserAccess(req.Uid)
  85. if err != nil {
  86. return err
  87. }
  88. return nil
  89. }
  90. func UpdateAccess(db orm.Ormer, uid, groupId int64) error {
  91. // where
  92. filter := map[string]interface{}{
  93. "uid": uid,
  94. }
  95. // value
  96. value := map[string]interface{}{
  97. "group_id": groupId,
  98. }
  99. p := dbmodel.TGdAdminRbacAccess{}
  100. _, err := p.Save(db, filter, value)
  101. if err != nil {
  102. l.Error("mysql",
  103. zap.String("sql", fmt.Sprintf("Update %s", p.TableName())),
  104. zap.String("fields", utils.MarshalJsonString(filter, value)),
  105. zap.String("error", err.Error()))
  106. return errors.DataBaseError
  107. }
  108. return nil
  109. }
  110. // 新增访问权限
  111. func AddAccess(db orm.Ormer, uid, groupId int64) error {
  112. // where
  113. p := dbmodel.TGdAdminRbacAccess{
  114. Uid: int(uid),
  115. GroupId: int(groupId),
  116. }
  117. _, err := p.Create(db)
  118. if err != nil {
  119. l.Error("mysql",
  120. zap.String("sql", fmt.Sprintf("INSERT %s", p.TableName())),
  121. zap.String("fields", utils.MarshalJsonString(p)),
  122. zap.String("error", err.Error()))
  123. return errors.DataBaseError
  124. }
  125. return nil
  126. }