package gateway

import (
	"context"
	"fmt"
	"io"
	"mime"
	"mime/multipart"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/gin-gonic/gin"
	"github.com/smallnest/rpcx/client"
	"github.com/smallnest/rpcx/log"
	"github.com/smallnest/rpcx/share"
)

// ServiceHandler converts an *http.Request into an rpcx request, sends it to
// the rpcx service named by the string argument (the servicePath), and returns
// the response metadata and payload so they can be written into the
// http.Response. Your web handler is responsible for supplying both the
// request and the servicePath.
type ServiceHandler func(*http.Request, string) (map[string]string, []byte, error)

// HTTPServer abstracts the Go web framework that serves the gateway.
// You can implement it with echo, gin, iris or any other Go web framework:
// wrap the ServiceHandler passed to RegisterHandler into a handler of the
// chosen framework and add it to that framework's router.
type HTTPServer interface {
	RegisterHandler(base string, handler ServiceHandler)
	Serve() error
}

// Gateway is an rpcx gateway that converts HTTP invocations into rpcx
// invocations.
type Gateway struct {
	// seq is the sequence number stamped onto outgoing rpcx requests. It is
	// updated with sync/atomic 64-bit operations, which on 32-bit platforms
	// require 64-bit alignment; the first field of an allocated struct is
	// guaranteed to be aligned, so seq must stay first.
	seq uint64

	base             string
	httpserver       HTTPServer
	serviceDiscovery client.ServiceDiscovery

	// FailMode, SelectMode and Option are passed to every XClient the
	// gateway creates.
	FailMode   client.FailMode
	SelectMode client.SelectMode
	Option     client.Option

	// mu guards xclients, which caches one XClient per service path.
	mu       sync.RWMutex
	xclients map[string]client.XClient
}

// NewGateway creates a Gateway and registers its handler with hs under base.
// base is normalized to start with "/"; an empty base becomes "/".
func NewGateway(base string, hs HTTPServer, sd client.ServiceDiscovery, failMode client.FailMode, selectMode client.SelectMode, option client.Option) *Gateway {
	// HasPrefix is false for "" as well, so this covers both the empty base
	// and a base missing its leading slash.
	if !strings.HasPrefix(base, "/") {
		base = "/" + base
	}

	gw := &Gateway{
		base:             base,
		httpserver:       hs,
		serviceDiscovery: sd,
		FailMode:         failMode,
		SelectMode:       selectMode,
		Option:           option,
		xclients:         make(map[string]client.XClient),
	}
	hs.RegisterHandler(gw.base, gw.handler)
	return gw
}

// Serve starts serving by calling Serve on the gateway's HTTPServer and
// returns whatever error that call returns.
func (g *Gateway) Serve() error { return g.httpserver.Serve() } func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) { contentType := r.Header.Get("Content-Type") mediaType, _, err := mime.ParseMediaType(contentType) var xc client.XClient g.mu.Lock() if mediaType == gin.MIMEMultipartPOSTForm { xc, err = getXClient(g, share.StreamServiceName) } else { xc, err = getXClient(g, servicePath) } g.mu.Unlock() if err != nil { return nil, nil, err } // 处理Auth token := getRequestToken(r) xc.Auth(token) if mediaType == gin.MIMEMultipartPOSTForm { formValues, formFile, err := MultipartRequest2RpcxRequest(r) formValues["__AUTH"] = token conn, callErr := xc.Stream(context.Background(), formValues) if callErr != nil { return nil, nil, err } //判断是否需要传输文件 if len(formFile) > 0 { //发送文件 err = sendFile(conn, formFile) if err != nil { return nil, nil, err } } //获取反馈结果 resp, err := io.ReadAll(conn) if err != nil { return nil, nil, err } conn.Close() return formValues, resp, err } else { req, err := HttpRequest2RpcxRequest(r) if err != nil { return nil, nil, err } seq := atomic.AddUint64(&g.seq, 1) req.SetSeq(seq) return xc.SendRaw(context.Background(), req) } } func getXClient(g *Gateway, servicePath string) (xc client.XClient, err error) { defer func() { if e := recover(); e != nil { if ee, ok := e.(error); ok { err = ee return } err = fmt.Errorf("failed to get xclient: %v", e) } }() if g.xclients[servicePath] == nil { d, err := g.serviceDiscovery.Clone(servicePath) if err != nil { return nil, err } g.xclients[servicePath] = client.NewXClient(servicePath, g.FailMode, g.SelectMode, d, g.Option) } xc = g.xclients[servicePath] return xc, err } // 解析token,若无,返回空 func getRequestToken(r *http.Request) string { authHeader := r.Header.Get("Authorization") if authHeader != "" { parts := strings.SplitN(authHeader, " ", 2) if !(len(parts) == 2 && parts[0] == "Bearer") { //glog.Warning("authHeader:" + authHeader + " get token key fail") 
return "" } else if parts[1] == "" { //glog.Warning("authHeader:" + authHeader + " get token fail") return "" } return parts[1] } return "" } // sendFile func sendFile(conn net.Conn, files map[string][]*multipart.FileHeader) error { index := 0 for key, header := range files { index++ file, _ := header[0].Open() fileName := header[0].Filename fileSize := header[0].Size fileHeader := fmt.Sprintf("%s %s %v", key, fileName, fileSize) // 发送文件名长度 length := strconv.Itoa(len(fileHeader)) conn.Write([]byte(PadLeft(length, 3, "0"))) // 发送文件名和文件长度给 接收端 conn.Write([]byte(fileHeader)) // 从本文件中,读数据,写给网络接收端。 buf := make([]byte, 1024) for { n, err := file.Read(buf) if n == 0 { log.Debug("发送文件完成") break } // 写到网络socket中 _, err = conn.Write(buf[:n]) if err != nil { log.Debug("conn.Write err:", err) break } } if index == len(files) { conn.Write([]byte("2")) } else { conn.Write([]byte("1")) } } return nil } func PadLeft(s string, length int, padding string) string { sLen := len(s) if sLen >= length { return s } padCount := length - sLen return strings.Repeat(padding, padCount) + s }