gateway.go 4.7 KB


  1. package gateway
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/smallnest/rpcx/log"
  7. "github.com/smallnest/rpcx/share"
  8. "io"
  9. "mime"
  10. "mime/multipart"
  11. "net"
  12. "net/http"
  13. "strings"
  14. "sync"
  15. "sync/atomic"
  16. "github.com/smallnest/rpcx/client"
  17. )
  // ServiceHandler forwards a single HTTP request to an rpcx service.
//
// It receives the incoming *http.Request together with the rpcx service path
// that the web-framework handler extracted from it. It converts the request
// into an rpcx call, and returns the reply metadata and payload for the web
// handler to write into the HTTP response.
type ServiceHandler func(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error)
  22. // HTTPServer is a golang web interface。
  23. // You can use echo, gin, iris or other go web frameworks to implement it.
  24. // You must wrap ServiceHandler into your handler of your selected web framework and add it into router.
  25. type HTTPServer interface {
  26. RegisterHandler(base string, handler ServiceHandler)
  27. Serve() error
  28. }
  29. // Gateway is a rpcx gateway which can convert http invoke into rpcx invoke.
  30. type Gateway struct {
  31. base string
  32. httpserver HTTPServer
  33. serviceDiscovery client.ServiceDiscovery
  34. FailMode client.FailMode
  35. SelectMode client.SelectMode
  36. Option client.Option
  37. mu sync.RWMutex
  38. xclients map[string]client.XClient
  39. seq uint64
  40. }
  // NewGateway builds a Gateway and registers its handler with hs under base.
//
// base is expected to be empty or look like "/abc/". It is normalized to start
// with "/", and an empty base becomes "/". failMode, selectMode and option
// are used for every XClient the gateway creates, and sd is cloned per
// service path when such a client is first needed.
func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
	// One check covers both an empty base and a base missing its leading slash.
	if !strings.HasPrefix(base, "/") {
		base = "/" + base
	}

	gw := new(Gateway)
	gw.base = base
	gw.httpserver = hs
	gw.serviceDiscovery = sd
	gw.FailMode = failMode
	gw.SelectMode = selectMode
	gw.Option = option
	gw.xclients = map[string]client.XClient{}

	hs.RegisterHandler(gw.base, gw.handler)
	return gw
}
  62. // Serve listens on the TCP network address addr and then calls
  63. // Serve with handler to handle requests on incoming connections.
  64. // Accepted connections are configured to enable TCP keep-alives.
  65. func (g *Gateway) Serve() error {
  66. return g.httpserver.Serve()
  67. }
  // handler is the ServiceHandler that NewGateway registers with the HTTPServer.
// It forwards one HTTP request to an rpcx service and returns the reply's
// metadata and payload.
//
// multipart/form-data requests are sent to the rpcx stream service
// (share.StreamServiceName), in three steps:
//   - the form values, plus the bearer token under "__AUTH", are passed as
//     the stream metadata;
//   - the uploaded file is written to the stream connection;
//   - everything read back from the connection is returned as the payload,
//     with the form values as metadata.
//
// Every other request is converted by HttpRequest2RpcxRequest, stamped with
// a gateway-wide sequence number and sent with XClient.SendRaw.
func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
	// A missing or malformed Content-Type is deliberately not an error:
	// mediaType is then "" and the request takes the plain SendRaw path.
	mediaType, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
	isMultipart := mediaType == gin.MIMEMultipartPOSTForm

	target := servicePath
	if isMultipart {
		target = share.StreamServiceName
	}
	// getXClient reads and fills the shared xclients cache, so hold g.mu.
	var xc client.XClient
	g.mu.Lock()
	xc, err = getXClient(g, target)
	g.mu.Unlock()
	if err != nil {
		return nil, nil, err
	}

	token := getRequestToken(r)
	if token != "" {
		// NOTE(review): xc is cached and shared by all requests for this
		// service path, and Auth is called on it outside g.mu. If Auth stores
		// the token on the client (presumably it does — confirm against the
		// rpcx client), concurrent requests race on it. A request without a
		// token would also reuse the last one set. Consider passing the token
		// per call instead.
		xc.Auth(token)
	}

	if !isMultipart {
		req, convErr := HttpRequest2RpcxRequest(r)
		if convErr != nil {
			return nil, nil, convErr
		}
		req.SetSeq(atomic.AddUint64(&g.seq, 1))
		return xc.SendRaw(context.Background(), req)
	}

	formValues, formFile, err := MultipartRequest2RpcxRequest(r)
	if err != nil {
		return nil, nil, err
	}
	if formValues == nil {
		// Guard against a nil map so the assignment below cannot panic.
		formValues = make(map[string]string)
	}
	formValues["__AUTH"] = token

	conn, err := xc.Stream(context.Background(), formValues)
	if err != nil {
		return nil, nil, err
	}
	// Release the stream connection on every exit path, not only on success.
	defer conn.Close()

	if err = sendFile(conn, formFile); err != nil {
		return nil, nil, err
	}
	resp, err := io.ReadAll(conn)
	if err != nil {
		return nil, nil, err
	}
	return formValues, resp, nil
}
  // getXClient returns the cached XClient for servicePath. On first use it
// creates one from a clone of the gateway's service discovery, configured
// with g.FailMode, g.SelectMode and g.Option, and caches it.
//
// It reads and writes g.xclients without locking, so callers must hold g.mu.
// A panic raised while building the client is recovered and returned as an
// error, in which case the returned XClient is nil.
func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		if asErr, isErr := r.(error); isErr {
			err = asErr
		} else {
			err = fmt.Errorf("failed to get xclient: %v", r)
		}
	}()

	if cached := g.xclients[servicePath]; cached != nil {
		return cached, nil
	}

	discovery, cloneErr := g.serviceDiscovery.Clone(servicePath)
	if cloneErr != nil {
		return nil, cloneErr
	}
	// Build into a local first so xc stays nil if storing into the map panics.
	fresh := client.NewXClient(servicePath, g.FailMode, g.SelectMode, discovery, g.Option)
	g.xclients[servicePath] = fresh
	return fresh, nil
}
  // getRequestToken extracts the bearer token from the request's Authorization
// header. It returns "" when the header is missing, uses a scheme other than
// Bearer, or carries no token.
//
// The scheme is matched case-insensitively, because auth-scheme names are
// case-insensitive (RFC 7235 §2.1). Any run of spaces between the scheme and
// the token is skipped, since RFC 6750 §2.1 allows one or more.
func getRequestToken(r *http.Request) string {
	authHeader := r.Header.Get("Authorization")
	if authHeader == "" {
		return ""
	}
	parts := strings.SplitN(authHeader, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return ""
	}
	// An empty result (e.g. "Bearer ") naturally yields "".
	return strings.TrimLeft(parts[1], " ")
}
  // sendFile copies the full contents of the uploaded file to conn.
//
// It returns an error if file is nil, cannot be opened, or if reading the
// file or writing to conn fails. The opened file is always closed before
// returning.
func sendFile(conn net.Conn, file *multipart.FileHeader) error {
	if file == nil {
		return fmt.Errorf("sending file: no file in multipart request")
	}
	f, err := file.Open()
	if err != nil {
		return err
	}
	defer f.Close()

	// io.Copy loops until EOF and reports both read and write failures, so a
	// broken connection is surfaced to the caller instead of being ignored.
	if _, err = io.Copy(conn, f); err != nil {
		return fmt.Errorf("sending file %q: %w", file.Filename, err)
	}
	log.Debug("发送文件完成") // "file sent"
	return nil
}