Ver Fonte

feat:适配文件上传

Cheng Jian há 2 anos atrás
pai
commit
4ec5b05818
4 ficheiros alterados com 86 adições e 123 exclusões
  1. 8 7
      micro_gateway/gin/server.go
  2. 3 14
      micro_gateway/main.go
  3. 23 86
      rpcx-gateway/converter.go
  4. 52 16
      rpcx-gateway/gateway.go

+ 8 - 7
micro_gateway/gin/server.go

@@ -53,10 +53,10 @@ func (s *Server) RegisterHandler(base string, handler ServiceHandler) {
 		// 自定义日志输出
 		g.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 			return fmt.Sprintf("%s [%s] %s%s %s\n",
-				param.ClientIP,  // 客户端IP
+				param.ClientIP,                     // 客户端IP
 				param.Request.Header.Get("Tenant"), // 租户码
-				param.Path,  // 请求路径
-				"/" + param.Request.Header.Get("X-RPCX-ServicePath")+"/"+param.Request.Header.Get("X-RPCX-ServiceMethod"),
+				param.Path,                         // 请求路径
+				"/"+param.Request.Header.Get("X-RPCX-ServicePath")+"/"+param.Request.Header.Get("X-RPCX-ServiceMethod"),
 				//param.StatusCode,  // 请求状态码
 				//param.Latency,  // 请求时长
 				param.ErrorMessage,
@@ -106,7 +106,7 @@ func wrapServiceHandler(handler ServiceHandler) gin.HandlerFunc {
 		tenant := r.Header.Get("Tenant")
 		if tenant != "" {
 			//r.Header.Set(XMeta, "tenant="+tenant)
-			xmeta = "tenant="+tenant
+			xmeta = "tenant=" + tenant
 		}
 
 		// 传递ClientIP和UserAgent
@@ -114,8 +114,8 @@ func wrapServiceHandler(handler ServiceHandler) gin.HandlerFunc {
 		if x_Meta != "" && strings.Contains(x_Meta, "need_clint_Info=1") {
 			clientIP := ctx.ClientIP()
 			userAgent := gbase64.EncodeString(r.UserAgent())
-			if xmeta!=""{
-				xmeta= xmeta + "&"
+			if xmeta != "" {
+				xmeta = xmeta + "&"
 			}
 			xmeta = xmeta + "clientIP=" + clientIP + "&userAgent=" + userAgent
 		}
@@ -182,6 +182,7 @@ func wrapServiceHandler(handler ServiceHandler) gin.HandlerFunc {
 		//wh.Set(XMessageStatusType, "Error")
 		//wh.Set(XErrorMessage, err.Error())
 		//ctx.String(http.StatusOK, err.Error())
+		//resp := errorJson(500, err.Error())
 		resp := errorJson(500, err.Error())
 		if err.Error() == "InvalidToken" {
 			resp.Code = 401
@@ -194,7 +195,7 @@ func (s *Server) Serve() error {
 	return s.g.Run(s.addr)
 }
 
-// 数据返回通用JSON数据结构
+// JsonResponse 数据返回通用JSON数据结构
 type JsonResponse struct {
 	Code int         `json:"code,omitempty"` // 错误码((200:成功, 其他是异常)
 	Msg  string      `json:"msg,omitempty"`  // 提示信息

+ 3 - 14
micro_gateway/main.go

@@ -22,11 +22,6 @@ var (
 	allowKeyNotFound = flag.Bool("key", true, "key, allow key not found")
 )
 
-type ServiceDiscovery struct {
-	BasePath     string
-	SrvDiscovery client.ServiceDiscovery
-}
-
 func main() {
 	flag.Parse()
 
@@ -37,19 +32,15 @@ func main() {
 
 	httpServer := gin.New(*addr)
 
-	var gws map[string]*gateway.Gateway
-	gws = make(map[string]*gateway.Gateway)
-
 	for key, value := range discoverys {
 		if strings.HasPrefix(key, "/") {
 			key = key[1:]
 		}
 		srvPath := "/" + key
-		gws[srvPath] = gateway.NewGateway(srvPath, httpServer, value, client.FailMode(*failmode), client.SelectMode(*selectMode), client.DefaultOption)
-
+		gateway.NewGateway(srvPath, httpServer, value, client.FailMode(*failmode), client.SelectMode(*selectMode), client.DefaultOption)
 	}
 
-	// 启动服务
+	// 启动服务Co
 	if err := httpServer.Serve(); err != nil {
 		log.Fatal(err)
 	}
@@ -57,8 +48,7 @@ func main() {
 }
 
 func createServiceDiscovery() (map[string]client.ServiceDiscovery, error) {
-	var serviceDiscoverys map[string]client.ServiceDiscovery
-	serviceDiscoverys = make(map[string]client.ServiceDiscovery)
+	serviceDiscoverys := make(map[string]client.ServiceDiscovery)
 
 	regAddr := *registry
 	i := strings.Index(regAddr, "://")
@@ -84,7 +74,6 @@ func createServiceDiscovery() (map[string]client.ServiceDiscovery, error) {
 			}
 			serviceDiscoverys[path] = discovery
 		}
-		//serviceDiscoverys["dashoo.biobank.adapter-0.1"] = discovery
 		//return serviceDiscoverys, err
 		return serviceDiscoverys, nil
 	case "consul":

+ 23 - 86
rpcx-gateway/converter.go

@@ -1,14 +1,11 @@
 package gateway
 
 import (
-	"encoding/json"
 	"errors"
 	"io/ioutil"
-	"log"
+	"mime/multipart"
 	"net/http"
 	"net/url"
-	"os"
-	"path"
 	"strconv"
 
 	"github.com/smallnest/rpcx/protocol"
@@ -106,110 +103,50 @@ func HttpRequest2RpcxRequest(r *http.Request) (*protocol.Message, error) {
 	return req, nil
 }
 
-func MultipartRequest2RpcxRequest(r *http.Request) (*protocol.Message, string, error) {
-	req := protocol.NewMessage()
-	req.SetMessageType(protocol.Request)
-
+func MultipartRequest2RpcxRequest(r *http.Request) (map[string]string, *multipart.FileHeader, error) {
+	r.ParseMultipartForm(10 << 20) //10mb
+	form := r.MultipartForm
+	formValues := make(map[string]string)
+	//获取 multi-part/form header的 value
 	h := r.Header
-	seq := h.Get(XMessageID)
-	if seq != "" {
-		id, err := strconv.ParseUint(seq, 10, 64)
-		if err != nil {
-			return nil, "", err
-		}
-		req.SetSeq(id)
-	}
-
-	heartbeat := h.Get(XHeartbeat)
-	if heartbeat != "" {
-		req.SetHeartbeat(true)
-	}
-
-	oneway := h.Get(XOneway)
-	if oneway != "" {
-		req.SetOneway(true)
-	}
-
-	if h.Get("Content-Encoding") == "gzip" {
-		req.SetCompressType(protocol.Gzip)
-	}
-
-	st := h.Get(XSerializeType)
-	if st != "" {
-		rst, err := strconv.Atoi(st)
-		if err != nil {
-			return nil, "", err
-		}
-		req.SetSerializeType(protocol.SerializeType(rst))
-	} else {
-		return nil, "", errors.New("empty serialized type")
-	}
-
 	meta := h.Get(XMeta)
 	if meta != "" {
 		metadata, err := url.ParseQuery(meta)
 		if err != nil {
-			return nil, "", err
+			return nil, nil, err
 		}
-		mm := make(map[string]string)
 		for k, v := range metadata {
 			if len(v) > 0 {
-				mm[k] = v[0]
+				formValues[k] = v[0]
 			}
 		}
-		req.Metadata = mm
 	}
 
+	//获取 multi-part/form body中的form value
+	for k, v := range form.Value {
+		formValues[k] = v[0]
+	}
+
+	file := form.File["file"][0]
+	formValues["fileName"] = file.Filename
+	formValues["fileSize"] = strconv.FormatInt(file.Size, 10)
+
 	sp := h.Get(XServicePath)
 	if sp != "" {
-		req.ServicePath = sp
+		formValues["reqService"] = sp
 	} else {
-		return nil, "", errors.New("empty servicepath")
+		return nil, nil, errors.New("empty servicepath")
 	}
 
 	sm := h.Get(XServiceMethod)
 	if sm != "" {
-		req.ServiceMethod = sm
+		formValues["reqMethod"] = sm
 	} else {
-		return nil, "", errors.New("empty servicemethod")
-	}
-
-	multipartReader, readerErr := r.MultipartReader()
-	if readerErr != nil {
-		return nil, "", readerErr
-	}
-
-	form, parseErr := multipartReader.ReadForm(32 << 20)
-	if parseErr != nil {
-		return nil, "", parseErr
-	}
-
-	metaForm := make(map[string]string)
-	for k, v := range form.Value {
-		metaForm[k] = v[0]
-	}
-	jstr, _ := json.Marshal(metaForm)
-	req.Payload = jstr
-
-	fh := form.File["file"][0]
-	fi, _ := fh.Open()
-	buf, _ := ioutil.ReadAll(fi)
-
-	suffix := path.Ext(fh.Filename)
-
-	tmpFile, err := ioutil.TempFile(os.TempDir(), "multipart-*"+suffix)
-	if err != nil {
-		log.Fatal("Cannot create temporary file", err)
-		return nil, "", err
-	}
-
-	_, err = tmpFile.Write(buf)
-	if err != nil {
-		log.Fatal("Failed to write to temporary file", err)
-		return nil, "", err
+		return nil, nil, errors.New("empty servicemethod")
 	}
+	formValues["authExclude"] = "false"
 
-	return req, tmpFile.Name(), nil
+	return formValues, form.File["file"][0], nil
 }
 
 // func RpcxResponse2HttpResponse(res *protocol.Message) (url.Values, []byte, error) {

+ 52 - 16
rpcx-gateway/gateway.go

@@ -2,12 +2,15 @@ package gateway
 
 import (
 	"context"
-	"encoding/json"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/smallnest/rpcx/log"
+	"github.com/smallnest/rpcx/share"
+	"io"
 	"mime"
+	"mime/multipart"
+	"net"
 	"net/http"
-	"os"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -77,9 +80,16 @@ func (g *Gateway) Serve() error {
 }
 
 func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]string, payload []byte, err error) {
+	contentType := r.Header.Get("Content-Type")
+	mediaType, _, err := mime.ParseMediaType(contentType)
+
 	var xc client.XClient
 	g.mu.Lock()
-	xc, err = getXClient(g, servicePath)
+	if mediaType == gin.MIMEMultipartPOSTForm {
+		xc, err = getXClient(g, share.StreamServiceName)
+	} else {
+		xc, err = getXClient(g, servicePath)
+	}
 	g.mu.Unlock()
 
 	if err != nil {
@@ -92,27 +102,29 @@ func (g *Gateway) handler(r *http.Request, servicePath string) (meta map[string]
 		xc.Auth(token)
 	}
 
-	contentType := r.Header.Get("Content-Type")
-	mediaType, _, err := mime.ParseMediaType(contentType)
-
 	if mediaType == gin.MIMEMultipartPOSTForm {
-		req, fileName, err := MultipartRequest2RpcxRequest(r)
-		formMeta := make(map[string]string)
-		err = json.Unmarshal(req.Payload, formMeta)
+		formValues, formFile, err := MultipartRequest2RpcxRequest(r)
+		formValues["__AUTH"] = token
 
-		defer os.Remove(fileName)
+		conn, callErr := xc.Stream(context.Background(), formValues)
+		if callErr != nil {
+			return nil, nil, err
+		}
 
-		err = xc.SendFile(context.Background(), fileName, 0, formMeta)
+		//发送文件
+		err = sendFile(conn, formFile)
+		if err != nil {
+			return nil, nil, err
+		}
 
+		//获取反馈结果
+		resp, err := io.ReadAll(conn)
 		if err != nil {
 			return nil, nil, err
 		}
+		conn.Close()
 
-		resp := make(map[string]interface{})
-		resp["code"] = 200
-		resp["msg"] = "提交成功!"
-		payload, err = json.Marshal(resp)
-		return req.Metadata, payload, err
+		return formValues, resp, err
 	} else {
 		req, err := HttpRequest2RpcxRequest(r)
 		if err != nil {
@@ -166,3 +178,27 @@ func getRequestToken(r *http.Request) string {
 	}
 	return ""
 }
+
+// sendFile
+func sendFile(conn net.Conn, file *multipart.FileHeader) error {
+	// 只读打开文件
+	f, err := file.Open()
+	if err != nil {
+		return err
+	}
+	// 从本文件中,读数据,写给网络接收端。
+	buf := make([]byte, 1024)
+	for {
+		n, err := f.Read(buf)
+		if n == 0 {
+			log.Debug("发送文件完成")
+			return nil
+		}
+		// 写到网络socket中
+		_, err = conn.Write(buf[:n])
+		if err != nil {
+			log.Debug("conn.Write err:", err)
+			return nil
+		}
+	}
+}