socket.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. package util
  2. import (
  3. "crypto/sha256"
  4. "crypto/subtle"
  5. "encoding/json"
  6. "fmt"
  7. "net"
  8. "net/http"
  9. "reflect"
  10. "strconv"
  11. "time"
  12. "gopkg.in/yaml.v3"
  13. )
  14. func FetchValue[T any](t T) func() T {
  15. return func() T {
  16. return t
  17. }
  18. }
  19. const (
  20. APIErrorNone = 0
  21. APIErrorDecode = iota + 4000
  22. APIErrorQueryParse
  23. APIErrorNoBody
  24. )
  25. const (
  26. APIErrorNotFound = iota + 4040
  27. APIErrorNoStream
  28. APIErrorNoConfig
  29. APIErrorNoPusher
  30. APIErrorNoSubscriber
  31. APIErrorNoSEI
  32. )
  33. const (
  34. APIErrorInternal = iota + 5000
  35. APIErrorJSONEncode
  36. APIErrorPublish
  37. APIErrorSave
  38. APIErrorOpen
  39. )
  40. type APIError struct {
  41. Code int `json:"code"`
  42. Message string `json:"msg"`
  43. }
  44. type APIResult struct {
  45. Code int `json:"code"`
  46. Data any `json:"data"`
  47. Message string `json:"msg"`
  48. }
  49. func ReturnValue(v any, rw http.ResponseWriter, r *http.Request) {
  50. ReturnFetchValue(FetchValue(v), rw, r)
  51. }
  52. func ReturnOK(rw http.ResponseWriter, r *http.Request) {
  53. ReturnError(0, "ok", rw, r)
  54. }
  55. func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) {
  56. query := r.URL.Query()
  57. isJson := query.Get("format") == "json"
  58. if isJson {
  59. if err := json.NewEncoder(rw).Encode(APIError{code, msg}); err != nil {
  60. json.NewEncoder(rw).Encode(APIError{
  61. Code: APIErrorJSONEncode,
  62. Message: err.Error(),
  63. })
  64. }
  65. } else {
  66. switch true {
  67. case code == 0:
  68. http.Error(rw, msg, http.StatusOK)
  69. case code/10 == 404:
  70. http.Error(rw, msg, http.StatusNotFound)
  71. case code > 5000:
  72. http.Error(rw, msg, http.StatusInternalServerError)
  73. default:
  74. http.Error(rw, msg, http.StatusBadRequest)
  75. }
  76. }
  77. }
  78. func ReturnFetchList[T any](fetch func() []T, rw http.ResponseWriter, r *http.Request) {
  79. query := r.URL.Query()
  80. isYaml := query.Get("format") == "yaml"
  81. isJson := query.Get("format") == "json"
  82. pageSize := query.Get("pageSize")
  83. pageNum := query.Get("pageNum")
  84. data := fetch()
  85. var output any
  86. output = data
  87. if pageSize != "" && pageNum != "" {
  88. pageSizeInt, _ := strconv.Atoi(pageSize)
  89. pageNumInt, _ := strconv.Atoi(pageNum)
  90. if pageSizeInt > 0 && pageNumInt > 0 {
  91. start := (pageNumInt - 1) * pageSizeInt
  92. end := pageNumInt * pageSizeInt
  93. if start > len(data) {
  94. start = len(data)
  95. }
  96. if end > len(data) {
  97. end = len(data)
  98. }
  99. output = map[string]any{
  100. "total": len(data),
  101. "list": data[start:end],
  102. "pageSize": pageSizeInt,
  103. "pageNum": pageNumInt,
  104. }
  105. }
  106. }
  107. rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json"))
  108. if isYaml {
  109. if err := yaml.NewEncoder(rw).Encode(output); err != nil {
  110. http.Error(rw, err.Error(), http.StatusInternalServerError)
  111. }
  112. } else if isJson {
  113. if err := json.NewEncoder(rw).Encode(APIResult{
  114. Code: 0,
  115. Data: output,
  116. Message: "ok",
  117. }); err != nil {
  118. json.NewEncoder(rw).Encode(APIError{
  119. Code: APIErrorJSONEncode,
  120. Message: err.Error(),
  121. })
  122. }
  123. } else {
  124. if err := json.NewEncoder(rw).Encode(output); err != nil {
  125. http.Error(rw, err.Error(), http.StatusInternalServerError)
  126. }
  127. }
  128. }
  129. func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) {
  130. query := r.URL.Query()
  131. isYaml := query.Get("format") == "yaml"
  132. isJson := query.Get("format") == "json"
  133. tickDur, err := time.ParseDuration(query.Get("interval"))
  134. if err != nil {
  135. tickDur = time.Second
  136. }
  137. if r.Header.Get("Accept") == "text/event-stream" {
  138. sse := NewSSE(rw, r.Context())
  139. tick := time.NewTicker(tickDur)
  140. defer tick.Stop()
  141. writer := Conditoinal(isYaml, sse.WriteYAML, sse.WriteJSON)
  142. writer(fetch())
  143. for range tick.C {
  144. if writer(fetch()) != nil {
  145. return
  146. }
  147. }
  148. } else {
  149. data := fetch()
  150. rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json"))
  151. if isYaml {
  152. if err := yaml.NewEncoder(rw).Encode(data); err != nil {
  153. http.Error(rw, err.Error(), http.StatusInternalServerError)
  154. }
  155. } else if isJson {
  156. if err := json.NewEncoder(rw).Encode(APIResult{
  157. Code: 0,
  158. Data: data,
  159. Message: "ok",
  160. }); err != nil {
  161. json.NewEncoder(rw).Encode(APIError{
  162. Code: APIErrorJSONEncode,
  163. Message: err.Error(),
  164. })
  165. }
  166. } else {
  167. t := reflect.TypeOf(data)
  168. switch t.Kind() {
  169. case reflect.String, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
  170. rw.Header().Set("Content-Type", "text/plain")
  171. fmt.Fprint(rw, data)
  172. default:
  173. if err := json.NewEncoder(rw).Encode(data); err != nil {
  174. http.Error(rw, err.Error(), http.StatusInternalServerError)
  175. }
  176. }
  177. }
  178. }
  179. }
  180. func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
  181. addr, err := net.ResolveUDPAddr("udp", address)
  182. if err != nil {
  183. return nil, err
  184. }
  185. conn, err := net.ListenUDP("udp", addr)
  186. if err != nil {
  187. return nil, err
  188. }
  189. if err = conn.SetReadBuffer(networkBuffer); err != nil {
  190. return nil, err
  191. }
  192. if err = conn.SetWriteBuffer(networkBuffer); err != nil {
  193. return nil, err
  194. }
  195. return conn, err
  196. }
  197. // CORS 加入跨域策略头包含CORP
  198. func CORS(next http.Handler) http.Handler {
  199. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  200. header := w.Header()
  201. header.Set("Access-Control-Allow-Credentials", "true")
  202. header.Set("Cross-Origin-Resource-Policy", "cross-origin")
  203. header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token")
  204. origin := r.Header["Origin"]
  205. if len(origin) == 0 {
  206. header.Set("Access-Control-Allow-Origin", "*")
  207. } else {
  208. header.Set("Access-Control-Allow-Origin", origin[0])
  209. }
  210. if next != nil && r.Method != "OPTIONS" {
  211. next.ServeHTTP(w, r)
  212. }
  213. })
  214. }
  215. func BasicAuth(u, p string, next http.Handler) http.Handler {
  216. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  217. // Extract the username and password from the request
  218. // Authorization header. If no Authentication header is present
  219. // or the header value is invalid, then the 'ok' return value
  220. // will be false.
  221. username, password, ok := r.BasicAuth()
  222. if ok {
  223. // Calculate SHA-256 hashes for the provided and expected
  224. // usernames and passwords.
  225. usernameHash := sha256.Sum256([]byte(username))
  226. passwordHash := sha256.Sum256([]byte(password))
  227. expectedUsernameHash := sha256.Sum256([]byte(u))
  228. expectedPasswordHash := sha256.Sum256([]byte(p))
  229. // 使用 subtle.ConstantTimeCompare() 进行校验
  230. // the provided username and password hashes equal the
  231. // expected username and password hashes. ConstantTimeCompare
  232. // 如果值相等,则返回1,否则返回0。
  233. // Importantly, we should to do the work to evaluate both the
  234. // username and password before checking the return values to
  235. // 避免泄露信息。
  236. usernameMatch := (subtle.ConstantTimeCompare(usernameHash[:], expectedUsernameHash[:]) == 1)
  237. passwordMatch := (subtle.ConstantTimeCompare(passwordHash[:], expectedPasswordHash[:]) == 1)
  238. // If the username and password are correct, then call
  239. // the next handler in the chain. Make sure to return
  240. // afterwards, so that none of the code below is run.
  241. if usernameMatch && passwordMatch {
  242. if next != nil {
  243. next.ServeHTTP(w, r)
  244. }
  245. return
  246. }
  247. }
  248. // If the Authentication header is not present, is invalid, or the
  249. // username or password is wrong, then set a WWW-Authenticate
  250. // header to inform the client that we expect them to use basic
  251. // authentication and send a 401 Unauthorized response.
  252. w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
  253. http.Error(w, "Unauthorized", http.StatusUnauthorized)
  254. })
  255. }