socket.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. package util
  2. import (
  3. "crypto/sha256"
  4. "crypto/subtle"
  5. "encoding/json"
  6. "net"
  7. "net/http"
  8. "strconv"
  9. "time"
  10. "gopkg.in/yaml.v3"
  11. )
  // FetchValue adapts a fixed value into a zero-argument getter that always
  // returns that same value. It lets static data be handed to helpers such as
  // ReturnFetchValue, which expect a fetch function rather than a value.
  func FetchValue[T any](t T) func() T {
  	getter := func() T { return t }
  	return getter
  }
  // API result codes reported in the "code" field of APIError and APIResult.
  // The numeric values are part of the response format seen by clients, so
  // they are spelled out explicitly here; do not renumber them.
  const (
  	// APIErrorNone means success (ReturnOK sends it with the message "ok").
  	APIErrorNone = 0

  	// Request errors. Numbering begins at 4001, not 4000.
  	APIErrorDecode     = 4001
  	APIErrorQueryParse = 4002
  	APIErrorNoBody     = 4003
  )

  // Not-found style errors, numbered from 4040 upward.
  const (
  	APIErrorNotFound     = 4040
  	APIErrorNoStream     = 4041
  	APIErrorNoConfig     = 4042
  	APIErrorNoPusher     = 4043
  	APIErrorNoSubscriber = 4044
  	APIErrorNoSEI        = 4045
  )

  // Server-side errors, numbered from 5000 upward.
  const (
  	APIErrorInternal   = 5000
  	APIErrorJSONEncode = 5001
  	APIErrorPublish    = 5002
  	APIErrorSave       = 5003
  	APIErrorOpen       = 5004
  )
  // JSON envelopes written when a request asks for ?format=json.
  type (
  	// APIError is the body ReturnError writes in JSON mode: a numeric code
  	// (one of the APIError* constants, or 0 for success) and a message.
  	// Keep the field order: ReturnError builds it positionally.
  	APIError struct {
  		Code    int    `json:"code"`
  		Message string `json:"msg"`
  	}

  	// APIResult wraps a successful payload in JSON mode; the Return* helpers
  	// fill it with Code 0 and Message "ok".
  	APIResult struct {
  		Code    int    `json:"code"`
  		Data    any    `json:"data"`
  		Message string `json:"msg"`
  	}
  )
  47. func ReturnValue(v any, rw http.ResponseWriter, r *http.Request) {
  48. ReturnFetchValue(FetchValue(v), rw, r)
  49. }
  50. func ReturnOK(rw http.ResponseWriter, r *http.Request) {
  51. ReturnError(0, "ok", rw, r)
  52. }
  53. func ReturnError(code int, msg string, rw http.ResponseWriter, r *http.Request) {
  54. query := r.URL.Query()
  55. isJson := query.Get("format") == "json"
  56. if isJson {
  57. if err := json.NewEncoder(rw).Encode(APIError{code, msg}); err != nil {
  58. json.NewEncoder(rw).Encode(APIError{
  59. Code: APIErrorJSONEncode,
  60. Message: err.Error(),
  61. })
  62. }
  63. } else {
  64. switch true {
  65. case code == 0:
  66. http.Error(rw, msg, http.StatusOK)
  67. case code/10 == 404:
  68. http.Error(rw, msg, http.StatusNotFound)
  69. case code > 5000:
  70. http.Error(rw, msg, http.StatusInternalServerError)
  71. default:
  72. http.Error(rw, msg, http.StatusBadRequest)
  73. }
  74. }
  75. }
  76. func ReturnFetchList[T any](fetch func() []T, rw http.ResponseWriter, r *http.Request) {
  77. query := r.URL.Query()
  78. isYaml := query.Get("format") == "yaml"
  79. isJson := query.Get("format") == "json"
  80. pageSize := query.Get("pageSize")
  81. pageNum := query.Get("pageNum")
  82. data := fetch()
  83. var output any
  84. output = data
  85. if pageSize != "" && pageNum != "" {
  86. pageSizeInt, _ := strconv.Atoi(pageSize)
  87. pageNumInt, _ := strconv.Atoi(pageNum)
  88. if pageSizeInt > 0 && pageNumInt > 0 {
  89. start := (pageNumInt - 1) * pageSizeInt
  90. end := pageNumInt * pageSizeInt
  91. if start > len(data) {
  92. start = len(data)
  93. }
  94. if end > len(data) {
  95. end = len(data)
  96. }
  97. output = map[string]any{
  98. "total": len(data),
  99. "list": data[start:end],
  100. "pageSize": pageSizeInt,
  101. "pageNum": pageNumInt,
  102. }
  103. }
  104. }
  105. rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json"))
  106. if isYaml {
  107. if err := yaml.NewEncoder(rw).Encode(output); err != nil {
  108. http.Error(rw, err.Error(), http.StatusInternalServerError)
  109. }
  110. } else if isJson {
  111. if err := json.NewEncoder(rw).Encode(APIResult{
  112. Code: 0,
  113. Data: output,
  114. Message: "ok",
  115. }); err != nil {
  116. json.NewEncoder(rw).Encode(APIError{
  117. Code: APIErrorJSONEncode,
  118. Message: err.Error(),
  119. })
  120. }
  121. } else {
  122. if err := json.NewEncoder(rw).Encode(output); err != nil {
  123. http.Error(rw, err.Error(), http.StatusInternalServerError)
  124. }
  125. }
  126. }
  127. func ReturnFetchValue[T any](fetch func() T, rw http.ResponseWriter, r *http.Request) {
  128. query := r.URL.Query()
  129. isYaml := query.Get("format") == "yaml"
  130. isJson := query.Get("format") == "json"
  131. tickDur, err := time.ParseDuration(query.Get("interval"))
  132. if err != nil {
  133. tickDur = time.Second
  134. }
  135. if r.Header.Get("Accept") == "text/event-stream" {
  136. sse := NewSSE(rw, r.Context())
  137. tick := time.NewTicker(tickDur)
  138. defer tick.Stop()
  139. writer := Conditoinal(isYaml, sse.WriteYAML, sse.WriteJSON)
  140. writer(fetch())
  141. for range tick.C {
  142. if writer(fetch()) != nil {
  143. return
  144. }
  145. }
  146. } else {
  147. data := fetch()
  148. rw.Header().Set("Content-Type", Conditoinal(isYaml, "text/yaml", "application/json"))
  149. if isYaml {
  150. if err := yaml.NewEncoder(rw).Encode(data); err != nil {
  151. http.Error(rw, err.Error(), http.StatusInternalServerError)
  152. }
  153. } else if isJson {
  154. if err := json.NewEncoder(rw).Encode(APIResult{
  155. Code: 0,
  156. Data: data,
  157. Message: "ok",
  158. }); err != nil {
  159. json.NewEncoder(rw).Encode(APIError{
  160. Code: APIErrorJSONEncode,
  161. Message: err.Error(),
  162. })
  163. }
  164. } else {
  165. if err := json.NewEncoder(rw).Encode(data); err != nil {
  166. http.Error(rw, err.Error(), http.StatusInternalServerError)
  167. }
  168. }
  169. }
  170. }
  // ListenUDP resolves address, binds a UDP socket to it, and sets both the
  // operating-system receive and send buffer sizes to networkBuffer bytes.
  //
  // On any failure, a socket that was already bound is closed, and nil is
  // returned together with the error.
  func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {
  	addr, err := net.ResolveUDPAddr("udp", address)
  	if err != nil {
  		return nil, err
  	}
  	conn, err := net.ListenUDP("udp", addr)
  	if err != nil {
  		return nil, err
  	}
  	if err = conn.SetReadBuffer(networkBuffer); err != nil {
  		conn.Close() // the caller never sees conn, so release the socket here
  		return nil, err
  	}
  	if err = conn.SetWriteBuffer(networkBuffer); err != nil {
  		conn.Close()
  		return nil, err
  	}
  	return conn, nil
  }
  // CORS wraps next with permissive cross-origin headers, including the
  // Cross-Origin-Resource-Policy (CORP) header.
  //
  // The headers it sets:
  //   - Access-Control-Allow-Origin echoes the request's Origin, or "*" when
  //     there is none.
  //   - Access-Control-Allow-Credentials is "true".
  //   - Access-Control-Allow-Headers permits Content-Type and Access-Token.
  //
  // OPTIONS (preflight) requests get these headers only and never reach next.
  // A nil next is treated as "headers only".
  //
  // NOTE(review): echoing any Origin together with
  // Access-Control-Allow-Credentials: true lets every website make
  // credentialed requests to the wrapped handler. Restrict the allowed origins
  // if that handler relies on cookies or HTTP authentication.
  func CORS(next http.Handler) http.Handler {
  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		header := w.Header()
  		header.Set("Access-Control-Allow-Credentials", "true")
  		header.Set("Cross-Origin-Resource-Policy", "cross-origin")
  		header.Set("Access-Control-Allow-Headers", "Content-Type,Access-Token")
  		// Allow-Origin depends on the request's Origin, so shared caches must key
  		// on it; without Vary one origin's response could be served to another.
  		header.Add("Vary", "Origin")
  		allowOrigin := "*"
  		if origin := r.Header.Get("Origin"); origin != "" {
  			allowOrigin = origin
  		}
  		header.Set("Access-Control-Allow-Origin", allowOrigin)
  		if next != nil && r.Method != http.MethodOptions {
  			next.ServeHTTP(w, r)
  		}
  	})
  }
  // BasicAuth guards next with HTTP Basic authentication against the single
  // username u and password p.
  //
  // If the credentials match, the request is passed to next; when next is nil,
  // nothing is written. Every other request gets a WWW-Authenticate challenge
  // and a 401 Unauthorized response. That covers a missing or malformed
  // Authorization header, a wrong username, and a wrong password.
  func BasicAuth(u, p string, next http.Handler) http.Handler {
  	// The expected credentials never change, so hash them once up front.
  	wantUser := sha256.Sum256([]byte(u))
  	wantPass := sha256.Sum256([]byte(p))

  	// credentialsMatch reports whether r carries the expected Basic credentials.
  	// Both sides are hashed to equal-length digests and compared with
  	// subtle.ConstantTimeCompare. Both comparisons are evaluated before being
  	// combined, so timing reveals neither which field was wrong nor how long
  	// the secrets are.
  	credentialsMatch := func(r *http.Request) bool {
  		user, pass, ok := r.BasicAuth()
  		if !ok {
  			return false
  		}
  		gotUser := sha256.Sum256([]byte(user))
  		gotPass := sha256.Sum256([]byte(pass))
  		userOK := subtle.ConstantTimeCompare(gotUser[:], wantUser[:]) == 1
  		passOK := subtle.ConstantTimeCompare(gotPass[:], wantPass[:]) == 1
  		return userOK && passOK
  	}

  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  		if !credentialsMatch(r) {
  			w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
  			http.Error(w, "Unauthorized", http.StatusUnauthorized)
  			return
  		}
  		if next != nil {
  			next.ServeHTTP(w, r)
  		}
  	})
  }