node_list_by_groupuser.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. package rbac
  2. import (
  3. "context"
  4. "cp-organization-management/errors"
  5. "cp-organization-management/model"
  6. pb_v1 "cp-organization-management/pb/v1"
  7. "cp-organization-management/utils"
  8. "encoding/json"
  9. "fmt"
  10. "github.com/jaryhe/gopkgs/database"
  11. "github.com/jaryhe/gopkgs/logger"
  12. "github.com/jinzhu/gorm"
  13. "go.uber.org/zap"
  14. "google.golang.org/grpc/status"
  15. "strconv"
  16. "strings"
  17. )
  18. func nodeSelect(m map[int64]bool, all []*pb_v1.RbacNodeItem) []*pb_v1.RbacNodeItem {
  19. for i, v := range all {
  20. if _, ok := m[v.Id]; ok {
  21. all[i].Select = true
  22. if len(all[i].Childs) > 0 {
  23. all[i].Childs = nodeSelect(m, all[i].Childs)
  24. }
  25. }
  26. }
  27. return all
  28. }
  29. func getNodeListByGroupId(id int64, code string)(reply *pb_v1.RbacNodeListByGroupOrUserReply, err error) {
  30. reply = &pb_v1.RbacNodeListByGroupOrUserReply{}
  31. dbname := utils.GetDbName(code)
  32. // 取id列表
  33. group := model.NewRbacGroup(dbname)
  34. where := map[string]interface{}{
  35. "id":id,
  36. }
  37. err = group.Find(database.DB(), where)
  38. if err != nil {
  39. return nil, status.Error(10012, "角色不存在")
  40. }
  41. if group.NodeList == "" {
  42. return reply, nil
  43. }
  44. idStrs := strings.Split(group.NodeList, ",")
  45. ids := make([]int64, len(idStrs))
  46. for i, v := range idStrs {
  47. ids[i], _ = strconv.ParseInt(v, 10, 64)
  48. }
  49. // 取目标group节点
  50. p := model.NewRbacNode(dbname)
  51. where = map[string]interface{}{
  52. "id in":ids,
  53. }
  54. list, err := p.ListAll(database.DB(), where, nil)
  55. if err != nil {
  56. if err == gorm.ErrRecordNotFound {
  57. return reply, nil
  58. }
  59. return nil, errors.DataBaseError
  60. }
  61. // 如果目标角色是超管,获取所有节点
  62. // 否则获取除super 外的所有节点
  63. mreq := pb_v1.RbacNodeListRequest{OrganizationCode:code, IsAll:false}
  64. if group.IsSuperGroup {
  65. mreq.IsAll = true
  66. }
  67. mreply, err := RbacNodeList(context.Background(), &mreq)
  68. if err != nil {
  69. return nil, err
  70. }
  71. m := map[int64]bool{}
  72. for _, v := range list {
  73. m[v.Id] = true
  74. }
  75. reply.List = nodeSelect(m, mreply.List)
  76. return reply, nil
  77. }
  78. func getNodeListByGroupIdOnlySelect(id int64, code string)(reply *pb_v1.RbacNodeListByGroupOrUserReply, err error) {
  79. reply = &pb_v1.RbacNodeListByGroupOrUserReply{}
  80. dbname := utils.GetDbName(code)
  81. // 取id列表
  82. group := model.NewRbacGroup(dbname)
  83. where := map[string]interface{}{
  84. "id":id,
  85. }
  86. err = group.Find(database.DB(), where)
  87. if err != nil {
  88. return nil, status.Error(10012, "角色不存在")
  89. }
  90. if group.NodeList == "" {
  91. return reply, nil
  92. }
  93. idStrs := strings.Split(group.NodeList, ",")
  94. ids := make([]int64, len(idStrs))
  95. for i, v := range idStrs {
  96. ids[i], _ = strconv.ParseInt(v, 10, 64)
  97. }
  98. // 取目标group节点
  99. p := model.NewRbacNode(dbname)
  100. where = map[string]interface{}{
  101. "id in":ids,
  102. }
  103. list, err := p.ListAll(database.DB(), where, nil)
  104. if err != nil {
  105. if err == gorm.ErrRecordNotFound {
  106. return reply, nil
  107. }
  108. return nil, errors.DataBaseError
  109. }
  110. reply.List = NodeTreeSelect(list)
  111. return reply, nil
  112. }
  113. func RbacNodeListByGroupOrUser(ctx context.Context, req *pb_v1.RbacNodeListByGroupOrUserRequest) (reply *pb_v1.RbacNodeListByGroupOrUserReply, err error) {
  114. reply = &pb_v1.RbacNodeListByGroupOrUserReply{}
  115. // 捕获各个task中的异常并返回给调用者
  116. defer func() {
  117. if r := recover(); r != nil {
  118. err = fmt.Errorf("%+v", r)
  119. e := &status.Status{}
  120. if er := json.Unmarshal([]byte(err.Error()), e); er != nil {
  121. logger.Error("err",
  122. zap.String("system_err", err.Error()),
  123. zap.Stack("stacktrace"))
  124. }
  125. }
  126. }()
  127. if req.OrganizationCode == "" {
  128. return nil, errors.ParamsError
  129. }
  130. if req.Uid < 1 && req.GroupId < 1 {
  131. return nil, errors.ParamsError
  132. }
  133. if req.GroupId > 0 {
  134. if req.Select {
  135. return getNodeListByGroupIdOnlySelect(req.GroupId, req.OrganizationCode)
  136. }
  137. return getNodeListByGroupId(req.GroupId, req.OrganizationCode)
  138. }
  139. dbname := utils.GetDbName(req.OrganizationCode)
  140. p := model.NewRbacUser(dbname)
  141. where := map[string]interface{}{
  142. "id":req.Uid,
  143. }
  144. err = p.Find(database.DB(), where)
  145. if err != nil {
  146. if err == gorm.ErrRecordNotFound {
  147. return nil, status.Error(30202, "用户不存在")
  148. }
  149. return nil, errors.DataBaseError
  150. }
  151. if req.Select {
  152. return getNodeListByGroupIdOnlySelect(p.GroupId, req.OrganizationCode)
  153. }
  154. return getNodeListByGroupId(p.GroupId, req.OrganizationCode)
  155. }