gateway.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package gateway
  2. import (
  3. "context"
  4. "fmt"
  5. "github.com/gin-gonic/gin"
  6. "github.com/smallnest/rpcx/log"
  7. "github.com/smallnest/rpcx/share"
  8. "io"
  9. "mime"
  10. "mime/multipart"
  11. "net"
  12. "net/http"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "sync/atomic"
  17. "github.com/smallnest/rpcx/client"
  18. )
  19. // ServiceHandler converts http.Request into rpcx.Request and send it to rpcx service,
  20. // and then converts the result and writes it into http.Response.
  21. // You should get the http.Request and servicePath in your web handler.
  22. type ServiceHandler func(*http.Request, string) (map[string]string, []byte, error)
  23. // HTTPServer is a golang web interface。
  24. // You can use echo, gin, iris or other go web frameworks to implement it.
  25. // You must wrap ServiceHandler into your handler of your selected web framework and add it into router.
  26. type HTTPServer interface {
  27. RegisterHandler(base string, handler ServiceHandler)
  28. Serve() error
  29. }
  30. // Gateway is a rpcx gateway which can convert http invoke into rpcx invoke.
  31. type Gateway struct {
  32. base string
  33. httpserver HTTPServer
  34. serviceDiscovery client.ServiceDiscovery
  35. FailMode client.FailMode
  36. SelectMode client.SelectMode
  37. Option client.Option
  38. mu sync.RWMutex
  39. xclients map[string]client.XClient
  40. seq uint64
  41. }
  42. // NewGateway returns a new gateway.
  43. func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
  44. // base is empty or like /abc/
  45. if base == "" {
  46. base = "/"
  47. }
  48. if base[0] != '/' {
  49. base = "/" + base
  50. }
  51. g := &Gateway{
  52. base: base,
  53. httpserver: hs,
  54. serviceDiscovery: sd,
  55. FailMode: failMode,
  56. SelectMode: selectMode,
  57. Option: option,
  58. xclients: make(map[string]client.XClient),
  59. }
  60. hs.RegisterHandler(base, g.handler)
  61. return g
  62. }
  63. // Serve listens on the TCP network address addr and then calls
  64. // Serve with handler to handle requests on incoming connections.
  65. // Accepted connections are configured to enable TCP keep-alives.
  66. func (g *Gateway) Serve() error {
  67. return g.httpserver.Serve()
  68. }
  69. func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
  70. contentType := r.Header.Get("Content-Type")
  71. mediaType, _, err := mime.ParseMediaType(contentType)
  72. var xc client.XClient
  73. g.mu.Lock()
  74. if mediaType == gin.MIMEMultipartPOSTForm {
  75. xc, err = getXClient(g, share.StreamServiceName)
  76. } else {
  77. xc, err = getXClient(g, servicePath)
  78. }
  79. g.mu.Unlock()
  80. if err != nil {
  81. return nil, nil, err
  82. }
  83. // 处理Auth
  84. token := getRequestToken(r)
  85. xc.Auth(token)
  86. if mediaType == gin.MIMEMultipartPOSTForm {
  87. formValues, formFile, err := MultipartRequest2RpcxRequest(r)
  88. formValues["__AUTH"] = token
  89. conn, callErr := xc.Stream(context.Background(), formValues)
  90. if callErr != nil {
  91. return nil, nil, err
  92. }
  93. //发送文件
  94. err = sendFile(conn, formFile)
  95. if err != nil {
  96. return nil, nil, err
  97. }
  98. //获取反馈结果
  99. resp, err := io.ReadAll(conn)
  100. if err != nil {
  101. return nil, nil, err
  102. }
  103. conn.Close()
  104. return formValues, resp, err
  105. } else {
  106. req, err := HttpRequest2RpcxRequest(r)
  107. if err != nil {
  108. return nil, nil, err
  109. }
  110. seq := atomic.AddUint64(&g.seq, 1)
  111. req.SetSeq(seq)
  112. return xc.SendRaw(context.Background(), req)
  113. }
  114. }
  115. func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) {
  116. defer func() {
  117. if e := recover(); e != nil {
  118. if ee, ok := e.(error); ok {
  119. err = ee
  120. return
  121. }
  122. err = fmt.Errorf("failed to get xclient: %v", e)
  123. }
  124. }()
  125. if g.xclients[servicePath] == nil {
  126. d, err := g.serviceDiscovery.Clone(servicePath)
  127. if err != nil {
  128. return nil, err
  129. }
  130. g.xclients[servicePath] = client.NewXClient(servicePath, g.FailMode, g.SelectMode, d, g.Option)
  131. }
  132. xc = g.xclients[servicePath]
  133. return xc, err
  134. }
  135. // 解析token,若无,返回空
  136. func getRequestToken(r *http.Request) string {
  137. authHeader := r.Header.Get("Authorization")
  138. if authHeader != "" {
  139. parts := strings.SplitN(authHeader, " ", 2)
  140. if !(len(parts) == 2 && parts[0] == "Bearer") {
  141. //glog.Warning("authHeader:" + authHeader + " get token key fail")
  142. return ""
  143. } else if parts[1] == "" {
  144. //glog.Warning("authHeader:" + authHeader + " get token fail")
  145. return ""
  146. }
  147. return parts[1]
  148. }
  149. return ""
  150. }
  151. // sendFile
  152. func sendFile(conn net.Conn, files map[string][]*multipart.FileHeader) error {
  153. index := 0
  154. for key, header := range files {
  155. index++
  156. file, _ := header[0].Open()
  157. fileName := header[0].Filename
  158. fileSize := header[0].Size
  159. fileHeader := fmt.Sprintf("%s %s %v", key, fileName, fileSize)
  160. // 发送文件名长度
  161. length := strconv.Itoa(len(fileHeader))
  162. conn.Write([]byte(PadLeft(length, 3, "0")))
  163. // 发送文件名和文件长度给 接收端
  164. conn.Write([]byte(fileHeader))
  165. // 从本文件中,读数据,写给网络接收端。
  166. buf := make([]byte, 1024)
  167. for {
  168. n, err := file.Read(buf)
  169. if n == 0 {
  170. log.Debug("发送文件完成")
  171. break
  172. }
  173. // 写到网络socket中
  174. _, err = conn.Write(buf[:n])
  175. if err != nil {
  176. log.Debug("conn.Write err:", err)
  177. break
  178. }
  179. }
  180. if index == len(files) {
  181. conn.Write([]byte("2"))
  182. } else {
  183. conn.Write([]byte("1"))
  184. }
  185. }
  186. return nil
  187. }
  188. func PadLeft(s string, length int, padding string) string {
  189. sLen := len(s)
  190. if sLen >= length {
  191. return s
  192. }
  193. padCount := length - sLen
  194. return strings.Repeat(padding, padCount) + s
  195. }