| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- package gateway
- import (
- "context"
- "fmt"
- "github.com/gin-gonic/gin"
- "github.com/smallnest/rpcx/log"
- "github.com/smallnest/rpcx/share"
- "io"
- "mime"
- "mime/multipart"
- "net"
- "net/http"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "github.com/smallnest/rpcx/client"
- )
// ServiceHandler forwards one incoming HTTP request to the rpcx service at
// servicePath and hands back the rpcx reply as metadata plus payload.
// The web-framework adapter that calls it is responsible for extracting
// servicePath from the request and for writing the returned values into
// the http.Response.
type ServiceHandler func(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error)
- // HTTPServer is a golang web interface。
- // You can use echo, gin, iris or other go web frameworks to implement it.
- // You must wrap ServiceHandler into your handler of your selected web framework and add it into router.
- type HTTPServer interface {
- RegisterHandler(base string, handler ServiceHandler)
- Serve() error
- }
- // Gateway is a rpcx gateway which can convert http invoke into rpcx invoke.
- type Gateway struct {
- base string
- httpserver HTTPServer
- serviceDiscovery client.ServiceDiscovery
- FailMode client.FailMode
- SelectMode client.SelectMode
- Option client.Option
- mu sync.RWMutex
- xclients map[string]client.XClient
- seq uint64
- }
- // NewGateway returns a new gateway.
- func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
- // base is empty or like /abc/
- if base == "" {
- base = "/"
- }
- if base[0] != '/' {
- base = "/" + base
- }
- g := &Gateway{
- base: base,
- httpserver: hs,
- serviceDiscovery: sd,
- FailMode: failMode,
- SelectMode: selectMode,
- Option: option,
- xclients: make(map[string]client.XClient),
- }
- hs.RegisterHandler(base, g.handler)
- return g
- }
- // Serve listens on the TCP network address addr and then calls
- // Serve with handler to handle requests on incoming connections.
- // Accepted connections are configured to enable TCP keep-alives.
- func (g *Gateway) Serve() error {
- return g.httpserver.Serve()
- }
- func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
- contentType := r.Header.Get("Content-Type")
- mediaType, _, err := mime.ParseMediaType(contentType)
- var xc client.XClient
- g.mu.Lock()
- if mediaType == gin.MIMEMultipartPOSTForm {
- xc, err = getXClient(g, share.StreamServiceName)
- } else {
- xc, err = getXClient(g, servicePath)
- }
- g.mu.Unlock()
- if err != nil {
- return nil, nil, err
- }
- // 处理Auth
- token := getRequestToken(r)
- xc.Auth(token)
- if mediaType == gin.MIMEMultipartPOSTForm {
- formValues, formFile, err := MultipartRequest2RpcxRequest(r)
- formValues["__AUTH"] = token
- conn, callErr := xc.Stream(context.Background(), formValues)
- if callErr != nil {
- return nil, nil, err
- }
- //判断是否需要传输文件
- if len(formFile) > 0 {
- //发送文件
- err = sendFile(conn, formFile)
- if err != nil {
- return nil, nil, err
- }
- }
- //获取反馈结果
- resp, err := io.ReadAll(conn)
- if err != nil {
- return nil, nil, err
- }
- conn.Close()
- return formValues, resp, err
- } else {
- req, err := HttpRequest2RpcxRequest(r)
- if err != nil {
- return nil, nil, err
- }
- seq := atomic.AddUint64(&g.seq, 1)
- req.SetSeq(seq)
- return xc.SendRaw(context.Background(), req)
- }
- }
- func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) {
- defer func() {
- if e := recover(); e != nil {
- if ee, ok := e.(error); ok {
- err = ee
- return
- }
- err = fmt.Errorf("failed to get xclient: %v", e)
- }
- }()
- if g.xclients[servicePath] == nil {
- d, err := g.serviceDiscovery.Clone(servicePath)
- if err != nil {
- return nil, err
- }
- g.xclients[servicePath] = client.NewXClient(servicePath, g.FailMode, g.SelectMode, d, g.Option)
- }
- xc = g.xclients[servicePath]
- return xc, err
- }
// getRequestToken extracts the bearer token from the request's Authorization
// header ("Bearer <token>"). It returns "" if the header is absent, uses a
// scheme other than Bearer, or carries an empty token.
//
// The scheme is compared case-insensitively, because RFC 7235 defines
// authentication schemes as case-insensitive. Whitespace surrounding the
// token is trimmed, since the header syntax allows more than one space
// after the scheme.
func getRequestToken(r *http.Request) string {
	parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return ""
	}
	return strings.TrimSpace(parts[1])
}
- // sendFile
- func sendFile(conn net.Conn, files map[string][]*multipart.FileHeader) error {
- index := 0
- for key, header := range files {
- index++
- file, _ := header[0].Open()
- fileName := header[0].Filename
- fileSize := header[0].Size
- fileHeader := fmt.Sprintf("%s %s %v", key, fileName, fileSize)
- // 发送文件名长度
- length := strconv.Itoa(len(fileHeader))
- conn.Write([]byte(PadLeft(length, 3, "0")))
- // 发送文件名和文件长度给 接收端
- conn.Write([]byte(fileHeader))
- // 从本文件中,读数据,写给网络接收端。
- buf := make([]byte, 1024)
- for {
- n, err := file.Read(buf)
- if n == 0 {
- log.Debug("发送文件完成")
- break
- }
- // 写到网络socket中
- _, err = conn.Write(buf[:n])
- if err != nil {
- log.Debug("conn.Write err:", err)
- break
- }
- }
- if index == len(files) {
- conn.Write([]byte("2"))
- } else {
- conn.Write([]byte("1"))
- }
- }
- return nil
- }
// PadLeft left-pads s with repetitions of padding until the result is length
// bytes long, like JavaScript's String.prototype.padStart. If s is already at
// least length bytes, or padding is empty, s is returned unchanged.
//
// Lengths are measured in bytes. When padding is longer than one byte, the
// last repetition is truncated so the result is exactly length bytes. A
// multi-byte UTF-8 padding string may therefore be cut mid-character.
func PadLeft(s string, length int, padding string) string {
	gap := length - len(s)
	if gap <= 0 || padding == "" {
		return s
	}
	// Repeat just enough copies to cover the gap, then trim the excess.
	reps := (gap + len(padding) - 1) / len(padding)
	return strings.Repeat(padding, reps)[:gap] + s
}
|