gateway.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. package gateway
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "github.com/gin-gonic/gin"
  7. "mime"
  8. "net/http"
  9. "os"
  10. "strings"
  11. "sync"
  12. "sync/atomic"
  13. "github.com/smallnest/rpcx/client"
  14. )
  15. // ServiceHandler converts http.Request into rpcx.Request and send it to rpcx service,
  16. // and then converts the result and writes it into http.Response.
  17. // You should get the http.Request and servicePath in your web handler.
  18. type ServiceHandler func(*http.Request, string) (map[string]string, []byte, error)
  19. // HTTPServer is a golang web interface。
  20. // You can use echo, gin, iris or other go web frameworks to implement it.
  21. // You must wrap ServiceHandler into your handler of your selected web framework and add it into router.
  22. type HTTPServer interface {
  23. RegisterHandler(base string, handler ServiceHandler)
  24. Serve() error
  25. }
  26. // Gateway is a rpcx gateway which can convert http invoke into rpcx invoke.
  27. type Gateway struct {
  28. base string
  29. httpserver HTTPServer
  30. serviceDiscovery client.ServiceDiscovery
  31. FailMode client.FailMode
  32. SelectMode client.SelectMode
  33. Option client.Option
  34. mu sync.RWMutex
  35. xclients map[string]client.XClient
  36. seq uint64
  37. }
  38. // NewGateway returns a new gateway.
  39. func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
  40. // base is empty or like /abc/
  41. if base == "" {
  42. base = "/"
  43. }
  44. if base[0] != '/' {
  45. base = "/" + base
  46. }
  47. g := &Gateway{
  48. base: base,
  49. httpserver: hs,
  50. serviceDiscovery: sd,
  51. FailMode: failMode,
  52. SelectMode: selectMode,
  53. Option: option,
  54. xclients: make(map[string]client.XClient),
  55. }
  56. hs.RegisterHandler(base, g.handler)
  57. return g
  58. }
  59. // Serve listens on the TCP network address addr and then calls
  60. // Serve with handler to handle requests on incoming connections.
  61. // Accepted connections are configured to enable TCP keep-alives.
  62. func (g *Gateway) Serve() error {
  63. return g.httpserver.Serve()
  64. }
  65. func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
  66. var xc client.XClient
  67. g.mu.Lock()
  68. xc, err = getXClient(g, servicePath)
  69. g.mu.Unlock()
  70. if err != nil {
  71. return nil, nil, err
  72. }
  73. // 处理Auth
  74. token := getRequestToken(r)
  75. if token != "" {
  76. xc.Auth(token)
  77. }
  78. contentType := r.Header.Get("Content-Type")
  79. mediaType, _, err := mime.ParseMediaType(contentType)
  80. if mediaType == gin.MIMEMultipartPOSTForm {
  81. req, fileName, err := MultipartRequest2RpcxRequest(r)
  82. formMeta := make(map[string]string)
  83. err = json.Unmarshal(req.Payload, formMeta)
  84. defer os.Remove(fileName)
  85. err = xc.SendFile(context.Background(), fileName, 0, formMeta)
  86. if err != nil {
  87. return nil, nil, err
  88. }
  89. resp := make(map[string]interface{})
  90. resp["code"] = 200
  91. resp["msg"] = "提交成功!"
  92. payload, err = json.Marshal(resp)
  93. return req.Metadata, payload, err
  94. } else {
  95. req, err := HttpRequest2RpcxRequest(r)
  96. if err != nil {
  97. return nil, nil, err
  98. }
  99. seq := atomic.AddUint64(&g.seq, 1)
  100. req.SetSeq(seq)
  101. return xc.SendRaw(context.Background(), req)
  102. }
  103. }
  104. func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) {
  105. defer func() {
  106. if e := recover(); e != nil {
  107. if ee, ok := e.(error); ok {
  108. err = ee
  109. return
  110. }
  111. err = fmt.Errorf("failed to get xclient: %v", e)
  112. }
  113. }()
  114. if g.xclients[servicePath] == nil {
  115. d, err := g.serviceDiscovery.Clone(servicePath)
  116. if err != nil {
  117. return nil, err
  118. }
  119. g.xclients[servicePath] = client.NewXClient(servicePath, g.FailMode, g.SelectMode, d, g.Option)
  120. }
  121. xc = g.xclients[servicePath]
  122. return xc, err
  123. }
  124. // 解析token,若无,返回空
  125. func getRequestToken(r *http.Request) string {
  126. authHeader := r.Header.Get("Authorization")
  127. if authHeader != "" {
  128. parts := strings.SplitN(authHeader, " ", 2)
  129. if !(len(parts) == 2 && parts[0] == "Bearer") {
  130. //glog.Warning("authHeader:" + authHeader + " get token key fail")
  131. return ""
  132. } else if parts[1] == "" {
  133. //glog.Warning("authHeader:" + authHeader + " get token fail")
  134. return ""
  135. }
  136. return parts[1]
  137. }
  138. return ""
  139. }