package gateway import ( "bytes" "errors" "io/ioutil" "mime/multipart" "net/http" "net/url" "strconv" "strings" "github.com/smallnest/rpcx/protocol" ) const ( XVersion = "X-RPCX-Version" XMessageType = "X-RPCX-MesssageType" XHeartbeat = "X-RPCX-Heartbeat" XOneway = "X-RPCX-Oneway" XMessageStatusType = "X-RPCX-MessageStatusType" XSerializeType = "X-RPCX-SerializeType" XMessageID = "X-RPCX-MessageID" XServicePath = "X-RPCX-ServicePath" XServiceMethod = "X-RPCX-ServiceMethod" XMeta = "X-RPCX-Meta" XErrorMessage = "X-RPCX-ErrorMessage" ) func HttpRequest2RpcxRequest(r *http.Request) (*protocol.Message, error) { req := protocol.NewMessage() req.SetMessageType(protocol.Request) h := r.Header seq := getRpcxHeader(r, XMessageID) if seq != "" { id, err := strconv.ParseUint(seq, 10, 64) if err != nil { return nil, err } req.SetSeq(id) } heartbeat := getRpcxHeader(r, XHeartbeat) if heartbeat != "" { req.SetHeartbeat(true) } oneway := getRpcxHeader(r, XOneway) if oneway != "" { req.SetOneway(true) } if h.Get("Content-Encoding") == "gzip" { req.SetCompressType(protocol.Gzip) } st := getRpcxHeader(r, XSerializeType) if st != "" { rst, err := strconv.Atoi(st) if err != nil { return nil, err } req.SetSerializeType(protocol.SerializeType(rst)) } else { return nil, errors.New("empty serialized type") } meta := getRpcxHeader(r, XMeta) if meta != "" { metadata, err := url.ParseQuery(meta) if err != nil { return nil, err } mm := make(map[string]string) for k, v := range metadata { if len(v) > 0 { mm[k] = v[0] } } req.Metadata = mm } req.Metadata = getUrlParams(r, req.Metadata) sp := getRpcxHeader(r, XServicePath) if sp != "" { req.ServicePath = sp } else { return nil, errors.New("empty servicepath") } sm := getRpcxHeader(r, XServiceMethod) if sm != "" { req.ServiceMethod = sm } else { return nil, errors.New("empty servicemethod") } payload, err := ioutil.ReadAll(r.Body) if err != nil { return nil, err } req.Payload = payload // Request.Body的读取, 读取数据时, 指针会对应移动至EOF, 所以下次读取的时候, seek指针还在EOF处 // 后续无法获取请求数据,在此读完又存 r.Body = ioutil.NopCloser(bytes.NewBuffer(payload)) return req, nil } func MultipartRequest2RpcxRequest(r *http.Request) (map[string]string, map[string][]*multipart.FileHeader, error) { r.ParseMultipartForm(10 << 20) //10mb form := r.MultipartForm formValues := make(map[string]string) //获取 multi-part/form header的 value h := r.Header meta := h.Get(XMeta) if meta != "" { metadata, err := url.ParseQuery(meta) if err != nil { return nil, nil, err } for k, v := range metadata { if len(v) > 0 { formValues[k] = v[0] } } } //获取 multi-part/form body中的form value for k, v := range form.Value { formValues[k] = v[0] } sp := h.Get(XServicePath) if sp != "" { formValues["reqService"] = sp } else { return nil, nil, errors.New("empty servicepath") } sm := h.Get(XServiceMethod) if sm != "" { formValues["reqMethod"] = sm } else { return nil, nil, errors.New("empty servicemethod") } formValues["authExclude"] = "false" formValues["fileNum"] = strconv.Itoa(len(form.File)) return formValues, form.File, nil } func getRpcxHeader(r *http.Request, key string) string { val := r.Header.Get(key) if val != "" { return val } else { if values, ok := r.URL.Query()[key]; ok && len(values) > 0 { return values[0] } else { return "" } } } func getUrlParams(r *http.Request, metadata map[string]string) map[string]string { if len(metadata) == 0 { metadata = make(map[string]string) } query := r.URL.Query() for k, v := range query { if !strings.HasPrefix(k, "X-RPCX-") && len(v) > 0 { metadata[k] = v[0] } } return metadata } // func RpcxResponse2HttpResponse(res *protocol.Message) (url.Values, []byte, error) { // m := make(url.Values) // m.Set(XVersion, strconv.Itoa(int(res.Version()))) // if res.IsHeartbeat() { // m.Set(XHeartbeat, "true") // } // if res.IsOneway() { // m.Set(XOneway, "true") // } // if res.MessageStatusType() == protocol.Error { // m.Set(XMessageStatusType, "Error") // } else { // m.Set(XMessageStatusType, "Normal") // } // if res.CompressType() == protocol.Gzip { // m.Set("Content-Encoding", "gzip") // } // m.Set(XSerializeType, strconv.Itoa(int(res.SerializeType()))) // m.Set(XMessageID, strconv.FormatUint(res.Seq(), 10)) // m.Set(XServicePath, res.ServicePath) // m.Set(XServiceMethod, res.ServiceMethod) // return m, res.Payload, nil // }