gateway.go 5.4 KB


  1. package gateway
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/smallnest/rpcx/log"
  7. "github.com/smallnest/rpcx/share"
  8. "io"
  9. "mime"
  10. "mime/multipart"
  11. "net"
  12. "net/http"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "github.com/smallnest/rpcx/client"
  18. )
  // ServiceHandler turns an incoming *http.Request into an rpcx request for the
  // service named by servicePath, sends it, and returns the reply's metadata and
  // payload (or an error). It does not write the HTTP response itself: the
  // web-framework handler that wraps it extracts the *http.Request and the
  // servicePath, calls it, and writes the returned values into the response.
  type ServiceHandler func(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error)
  23. // HTTPServer is a golang web interface。
  24. // You can use echo, gin, iris or other go web frameworks to implement it.
  25. // You must wrap ServiceHandler into your handler of your selected web framework and add it into router.
  26. type HTTPServer interface {
  27. RegisterHandler(base string, handler ServiceHandler)
  28. Serve() error
  29. }
  30. // Gateway is a rpcx gateway which can convert http invoke into rpcx invoke.
  31. type Gateway struct {
  32. base string
  33. httpserver HTTPServer
  34. serviceDiscovery client.ServiceDiscovery
  35. FailMode client.FailMode
  36. SelectMode client.SelectMode
  37. Option client.Option
  38. mu sync.RWMutex
  39. xclients map[string]client.XClient
  40. seq uint64
  41. }
  42. // NewGateway returns a new gateway.
  43. func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
  44. // base is empty or like /abc/
  45. if base == "" {
  46. base = "/"
  47. }
  48. if base[0] != '/' {
  49. base = "/" + base
  50. }
  51. g := &Gateway{
  52. base: base,
  53. httpserver: hs,
  54. serviceDiscovery: sd,
  55. FailMode: failMode,
  56. SelectMode: selectMode,
  57. Option: option,
  58. xclients: make(map[string]client.XClient),
  59. }
  60. hs.RegisterHandler(base, g.handler)
  61. return g
  62. }
  63. // Serve listens on the TCP network address addr and then calls
  64. // Serve with handler to handle requests on incoming connections.
  65. // Accepted connections are configured to enable TCP keep-alives.
  66. func (g *Gateway) Serve() error {
  67. return g.httpserver.Serve()
  68. }
  69. func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
  70. contentType := r.Header.Get("Content-Type")
  71. mediaType, _, err := mime.ParseMediaType(contentType)
  72. var xc client.XClient
  73. g.mu.Lock()
  74. if mediaType == gin.MIMEMultipartPOSTForm {
  75. xc, err = getXClient(g, share.StreamServiceName)
  76. } else {
  77. xc, err = getXClient(g, servicePath)
  78. }
  79. g.mu.Unlock()
  80. if err != nil {
  81. return nil, nil, err
  82. }
  83. // 处理Auth
  84. token := getRequestToken(r)
  85. xc.Auth(token)
  86. if mediaType == gin.MIMEMultipartPOSTForm {
  87. formValues, formFile, err := MultipartRequest2RpcxRequest(r)
  88. formValues["__AUTH"] = token
  89. conn, callErr := xc.Stream(context.Background(), formValues)
  90. if callErr != nil {
  91. return nil, nil, err
  92. }
  93. //判断是否需要传输文件
  94. if len(formFile) > 0 {
  95. //发送文件
  96. err = sendFile(conn, formFile)
  97. if err != nil {
  98. return nil, nil, err
  99. }
  100. }
  101. //获取反馈结果
  102. resp, err := io.ReadAll(conn)
  103. if err != nil {
  104. return nil, nil, err
  105. }
  106. conn.Close()
  107. return formValues, resp, err
  108. } else {
  109. req, err := HttpRequest2RpcxRequest(r)
  110. if err != nil {
  111. return nil, nil, err
  112. }
  113. seq := atomic.AddUint64(&g.seq, 1)
  114. req.SetSeq(seq)
  115. return xc.SendRaw(context.Background(), req)
  116. }
  117. }
  118. func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) {
  119. defer func() {
  120. if e := recover(); e != nil {
  121. if ee, ok := e.(error); ok {
  122. err = ee
  123. return
  124. }
  125. err = fmt.Errorf("failed to get xclient: %v", e)
  126. }
  127. }()
  128. if g.xclients[servicePath] == nil {
  129. d, err := g.serviceDiscovery.Clone(servicePath)
  130. if err != nil {
  131. return nil, err
  132. }
  133. g.xclients[servicePath] = client.NewXClient(servicePath, g.FailMode, g.SelectMode, d, g.Option)
  134. }
  135. xc = g.xclients[servicePath]
  136. return xc, err
  137. }
  // getRequestToken extracts the bearer token from the request's Authorization
  // header ("Bearer <token>"). It returns "" when the header is missing, uses a
  // different scheme, or carries an empty token.
  //
  // The scheme name is matched case-insensitively, because HTTP auth schemes
  // are case-insensitive (RFC 7235 §2.1). Surrounding whitespace is trimmed
  // from the token.
  func getRequestToken(r *http.Request) string {
  	parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
  	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
  		return ""
  	}
  	return strings.TrimSpace(parts[1])
  }
  154. // sendFile
  155. func sendFile(conn net.Conn, files map[string][]*multipart.FileHeader) error {
  156. index := 0
  157. for key, header := range files {
  158. index++
  159. file, _ := header[0].Open()
  160. fileName := header[0].Filename
  161. fileSize := header[0].Size
  162. fileHeader := fmt.Sprintf("%s %s %v", key, fileName, fileSize)
  163. // 发送文件名长度
  164. length := strconv.Itoa(len(fileHeader))
  165. conn.Write([]byte(PadLeft(length, 3, "0")))
  166. // 发送文件名和文件长度给 接收端
  167. conn.Write([]byte(fileHeader))
  168. // 从本文件中,读数据,写给网络接收端。
  169. buf := make([]byte, 1024)
  170. for {
  171. n, err := file.Read(buf)
  172. if n == 0 {
  173. log.Debug("发送文件完成")
  174. break
  175. }
  176. // 写到网络socket中
  177. _, err = conn.Write(buf[:n])
  178. if err != nil {
  179. log.Debug("conn.Write err:", err)
  180. break
  181. }
  182. }
  183. if index == len(files) {
  184. conn.Write([]byte("2"))
  185. } else {
  186. conn.Write([]byte("1"))
  187. }
  188. }
  189. return nil
  190. }
  // PadLeft returns s prefixed with copies of padding so that it reaches length
  // bytes. s is returned unchanged when len(s) >= length.
  //
  // Lengths are byte counts, and padding is repeated once per missing byte.
  // The result therefore has exactly length bytes only when padding is a
  // single byte, as in PadLeft("7", 3, "0") == "007".
  func PadLeft(s string, length int, padding string) string {
  	missing := length - len(s)
  	if missing <= 0 {
  		return s
  	}

  	var b strings.Builder
  	b.Grow(missing*len(padding) + len(s))
  	for i := 0; i < missing; i++ {
  		b.WriteString(padding)
  	}
  	b.WriteString(s)
  	return b.String()
  }