diff -Nru golang-github-valyala-fasthttp-20160617/allocation_test.go golang-github-valyala-fasthttp-1.31.0/allocation_test.go --- golang-github-valyala-fasthttp-20160617/allocation_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/allocation_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,85 @@ +//go:build !race +// +build !race + +package fasthttp + +import ( + "net" + + "testing" +) + +func TestAllocationServeConn(t *testing.T) { + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + + rw := &readWriter{} + // Make space for the request and response here so it + // doesn't allocate within the test. + rw.r.Grow(1024) + rw.w.Grow(1024) + + n := testing.AllocsPerRun(100, func() { + rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\nCookie: foo=bar\r\n\r\n") + if err := s.ServeConn(rw); err != nil { + t.Fatal(err) + } + + // Reset the write buffer to make space for the next response. + rw.w.Reset() + }) + + if n != 0 { + t.Fatalf("expected 0 allocations, got %f", n) + } +} + +func TestAllocationClient(t *testing.T) { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("cannot listen: %s", err) + } + defer ln.Close() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go s.Serve(ln) //nolint:errcheck + + c := &Client{} + url := "http://test:test@" + ln.Addr().String() + "/foo?bar=baz" + + n := testing.AllocsPerRun(100, func() { + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI(url) + if err := c.Do(req, res); err != nil { + t.Fatal(err) + } + + ReleaseRequest(req) + ReleaseResponse(res) + }) + + if n != 0 { + t.Fatalf("expected 0 allocations, got %f", n) + } +} + +func TestAllocationURI(t *testing.T) { + uri := []byte("http://username:password@hello.%e4%b8%96%e7%95%8c.com/some/path?foo=bar#test") + + n := testing.AllocsPerRun(100, func() { + u := AcquireURI() + u.Parse(nil, uri) //nolint:errcheck + ReleaseURI(u) + }) + + if n != 0 { + t.Fatalf("expected 0 allocations, got %f", n) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/args.go golang-github-valyala-fasthttp-1.31.0/args.go --- golang-github-valyala-fasthttp-20160617/args.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/args.go 2021-10-09 18:39:05.000000000 +0000 @@ -4,7 +4,15 @@ "bytes" "errors" "io" + "sort" "sync" + + "github.com/valyala/bytebufferpool" +) + +const ( + argsNoValue = true + argsHasValue = false ) // AcquireArgs returns an empty Args object from the pool. @@ -15,7 +23,7 @@ return argsPool.Get().(*Args) } -// ReleaseArgs returns the object acquired via AquireArgs to the pool. +// ReleaseArgs returns the object acquired via AcquireArgs to the pool. // // Do not access the released Args object, otherwise data races may occur. func ReleaseArgs(a *Args) { @@ -36,15 +44,16 @@ // // Args instance MUST NOT be used from concurrently running goroutines. type Args struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck args []argsKV buf []byte } type argsKV struct { - key []byte - value []byte + key []byte + value []byte + noValue bool } // Reset clears query args. @@ -101,20 +110,36 @@ // QueryString returns query string for the args. // -// The returned value is valid until the next call to Args methods. +// The returned value is valid until the Args is reused or released (ReleaseArgs). +// Do not store references to the returned value. Make copies instead. 
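The ownership rule documented above (buffers returned by Args methods stay valid only until the Args is reused or released) is easiest to see in a minimal usage sketch. Package main, the sample key/value and the printed output are illustrative assumptions; AcquireArgs, Set, QueryString and ReleaseArgs are the functions introduced in this file:

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	a := fasthttp.AcquireArgs()
	a.Set("foo", "bar baz")

	// QueryString returns a slice backed by the Args' internal buffer,
	// so copy it before releasing the Args back to the pool.
	qs := append([]byte(nil), a.QueryString()...)
	fasthttp.ReleaseArgs(a)

	fmt.Printf("%s\n", qs) // foo=bar+baz (space is encoded as '+')
}
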
func (a *Args) QueryString() []byte { a.buf = a.AppendBytes(a.buf[:0]) return a.buf } +// Sort sorts Args by key and then value using 'f' as comparison function. +// +// For example args.Sort(bytes.Compare) +func (a *Args) Sort(f func(x, y []byte) int) { + sort.SliceStable(a.args, func(i, j int) bool { + n := f(a.args[i].key, a.args[j].key) + if n == 0 { + return f(a.args[i].value, a.args[j].value) == -1 + } + return n == -1 + }) +} + // AppendBytes appends query string to dst and returns the extended dst. func (a *Args) AppendBytes(dst []byte) []byte { for i, n := 0, len(a.args); i < n; i++ { kv := &a.args[i] dst = AppendQuotedArg(dst, kv.key) - if len(kv.value) > 0 { + if !kv.noValue { dst = append(dst, '=') - dst = AppendQuotedArg(dst, kv.value) + if len(kv.value) > 0 { + dst = AppendQuotedArg(dst, kv.value) + } } if i+1 < n { dst = append(dst, '&') @@ -145,60 +170,88 @@ // // Multiple values for the same key may be added. func (a *Args) Add(key, value string) { - a.args = appendArg(a.args, key, value) + a.args = appendArg(a.args, key, value, argsHasValue) } // AddBytesK adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesK(key []byte, value string) { - a.args = appendArg(a.args, b2s(key), value) + a.args = appendArg(a.args, b2s(key), value, argsHasValue) } // AddBytesV adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesV(key string, value []byte) { - a.args = appendArg(a.args, key, b2s(value)) + a.args = appendArg(a.args, key, b2s(value), argsHasValue) } // AddBytesKV adds 'key=value' argument. // // Multiple values for the same key may be added. func (a *Args) AddBytesKV(key, value []byte) { - a.args = appendArg(a.args, b2s(key), b2s(value)) + a.args = appendArg(a.args, b2s(key), b2s(value), argsHasValue) +} + +// AddNoValue adds only 'key' as argument without the '='. +// +// Multiple values for the same key may be added. +func (a *Args) AddNoValue(key string) { + a.args = appendArg(a.args, key, "", argsNoValue) +} + +// AddBytesKNoValue adds only 'key' as argument without the '='. +// +// Multiple values for the same key may be added. +func (a *Args) AddBytesKNoValue(key []byte) { + a.args = appendArg(a.args, b2s(key), "", argsNoValue) } // Set sets 'key=value' argument. func (a *Args) Set(key, value string) { - a.args = setArg(a.args, key, value) + a.args = setArg(a.args, key, value, argsHasValue) } // SetBytesK sets 'key=value' argument. func (a *Args) SetBytesK(key []byte, value string) { - a.args = setArg(a.args, b2s(key), value) + a.args = setArg(a.args, b2s(key), value, argsHasValue) } // SetBytesV sets 'key=value' argument. func (a *Args) SetBytesV(key string, value []byte) { - a.args = setArg(a.args, key, b2s(value)) + a.args = setArg(a.args, key, b2s(value), argsHasValue) } // SetBytesKV sets 'key=value' argument. func (a *Args) SetBytesKV(key, value []byte) { - a.args = setArgBytes(a.args, key, value) + a.args = setArgBytes(a.args, key, value, argsHasValue) +} + +// SetNoValue sets only 'key' as argument without the '='. +// +// Only key in argumemt, like key1&key2 +func (a *Args) SetNoValue(key string) { + a.args = setArg(a.args, key, "", argsNoValue) +} + +// SetBytesKNoValue sets 'key' argument. +func (a *Args) SetBytesKNoValue(key []byte) { + a.args = setArg(a.args, b2s(key), "", argsNoValue) } // Peek returns query arg value for the given key. // -// Returned value is valid until the next Args call. 
+// The returned value is valid until the Args is reused or released (ReleaseArgs). +// Do not store references to the returned value. Make copies instead. func (a *Args) Peek(key string) []byte { return peekArgStr(a.args, key) } // PeekBytes returns query arg value for the given key. // -// Returned value is valid until the next Args call. +// The returned value is valid until the Args is reused or released (ReleaseArgs). +// Do not store references to the returned value. Make copies instead. func (a *Args) PeekBytes(key []byte) []byte { return peekArgBytes(a.args, key) } @@ -243,10 +296,10 @@ // SetUint sets uint value for the given key. func (a *Args) SetUint(key string, value int) { - bb := AcquireByteBuffer() + bb := bytebufferpool.Get() bb.B = AppendUint(bb.B[:0], value) a.SetBytesV(key, bb.B) - ReleaseByteBuffer(bb) + bytebufferpool.Put(bb) } // SetUintBytes sets uint value for the given key. @@ -285,6 +338,22 @@ return f } +// GetBool returns boolean value for the given key. +// +// true is returned for "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes", +// otherwise false is returned. +func (a *Args) GetBool(key string) bool { + switch b2s(a.Peek(key)) { + // Support the same true cases as strconv.ParseBool + // See: https://github.com/golang/go/blob/4e1b11e2c9bdb0ddea1141eed487be1a626ff5be/src/strconv/atob.go#L12 + // and Y and Yes versions. + case "1", "t", "T", "true", "TRUE", "True", "y", "yes", "Y", "YES", "Yes": + return true + default: + return false + } +} + func visitArgs(args []argsKV, f func(k, v []byte)) { for i, n := 0, len(args); i < n; i++ { kv := &args[i] @@ -295,7 +364,13 @@ func copyArgs(dst, src []argsKV) []argsKV { if cap(dst) < len(src) { tmp := make([]argsKV, len(src)) + dst = dst[:cap(dst)] // copy all of dst. copy(tmp, dst) + for i := len(dst); i < len(tmp); i++ { + // Make sure nothing is nil. + tmp[i].key = []byte{} + tmp[i].value = []byte{} + } dst = tmp } n := len(src) @@ -304,7 +379,12 @@ dstKV := &dst[i] srcKV := &src[i] dstKV.key = append(dstKV.key[:0], srcKV.key...) - dstKV.value = append(dstKV.value[:0], srcKV.value...) + if srcKV.noValue { + dstKV.value = dstKV.value[:0] + } else { + dstKV.value = append(dstKV.value[:0], srcKV.value...) + } + dstKV.noValue = srcKV.noValue } return dst } @@ -320,6 +400,7 @@ tmp := *kv copy(args[i:], args[i+1:]) n-- + i-- args[n] = tmp args = args[:n] } @@ -327,31 +408,41 @@ return args } -func setArgBytes(h []argsKV, key, value []byte) []argsKV { - return setArg(h, b2s(key), b2s(value)) +func setArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV { + return setArg(h, b2s(key), b2s(value), noValue) } -func setArg(h []argsKV, key, value string) []argsKV { +func setArg(h []argsKV, key, value string, noValue bool) []argsKV { n := len(h) for i := 0; i < n; i++ { kv := &h[i] if key == string(kv.key) { - kv.value = append(kv.value[:0], value...) + if noValue { + kv.value = kv.value[:0] + } else { + kv.value = append(kv.value[:0], value...) 
+ } + kv.noValue = noValue return h } } - return appendArg(h, key, value) + return appendArg(h, key, value, noValue) } -func appendArgBytes(h []argsKV, key, value []byte) []argsKV { - return appendArg(h, b2s(key), b2s(value)) +func appendArgBytes(h []argsKV, key, value []byte, noValue bool) []argsKV { + return appendArg(h, b2s(key), b2s(value), noValue) } -func appendArg(args []argsKV, key, value string) []argsKV { +func appendArg(args []argsKV, key, value string, noValue bool) []argsKV { var kv *argsKV args, kv = allocArg(args) kv.key = append(kv.key[:0], key...) - kv.value = append(kv.value[:0], value...) + if noValue { + kv.value = kv.value[:0] + } else { + kv.value = append(kv.value[:0], value...) + } + kv.noValue = noValue return args } @@ -360,7 +451,9 @@ if cap(h) > n { h = h[:n+1] } else { - h = append(h, argsKV{}) + h = append(h, argsKV{ + value: []byte{}, + }) } return h, &h[n] } @@ -407,6 +500,7 @@ if len(s.b) == 0 { return false } + kv.noValue = argsHasValue isKey := true k := 0 @@ -415,15 +509,16 @@ case '=': if isKey { isKey = false - kv.key = decodeArg(kv.key, s.b[:i], true) + kv.key = decodeArgAppend(kv.key[:0], s.b[:i]) k = i + 1 } case '&': if isKey { - kv.key = decodeArg(kv.key, s.b[:i], true) + kv.key = decodeArgAppend(kv.key[:0], s.b[:i]) kv.value = kv.value[:0] + kv.noValue = argsNoValue } else { - kv.value = decodeArg(kv.value, s.b[k:i], true) + kv.value = decodeArgAppend(kv.value[:0], s.b[k:i]) } s.b = s.b[i+1:] return true @@ -431,39 +526,75 @@ } if isKey { - kv.key = decodeArg(kv.key, s.b, true) + kv.key = decodeArgAppend(kv.key[:0], s.b) kv.value = kv.value[:0] + kv.noValue = argsNoValue } else { - kv.value = decodeArg(kv.value, s.b[k:], true) + kv.value = decodeArgAppend(kv.value[:0], s.b[k:]) } s.b = s.b[len(s.b):] return true } -func decodeArg(dst, src []byte, decodePlus bool) []byte { - return decodeArgAppend(dst[:0], src, decodePlus) -} +func decodeArgAppend(dst, src []byte) []byte { + if bytes.IndexByte(src, '%') < 0 && bytes.IndexByte(src, '+') < 0 { + // fast path: src doesn't contain encoded chars + return append(dst, src...) + } -func decodeArgAppend(dst, src []byte, decodePlus bool) []byte { - for i, n := 0, len(src); i < n; i++ { + // slow path + for i := 0; i < len(src); i++ { c := src[i] if c == '%' { - if i+2 >= n { + if i+2 >= len(src) { return append(dst, src[i:]...) } - x1 := hexbyte2int(src[i+1]) - x2 := hexbyte2int(src[i+2]) - if x1 < 0 || x2 < 0 { - dst = append(dst, c) + x2 := hex2intTable[src[i+2]] + x1 := hex2intTable[src[i+1]] + if x1 == 16 || x2 == 16 { + dst = append(dst, '%') } else { - dst = append(dst, byte(x1<<4|x2)) + dst = append(dst, x1<<4|x2) i += 2 } - } else if decodePlus && c == '+' { + } else if c == '+' { dst = append(dst, ' ') } else { dst = append(dst, c) } + } + return dst +} + +// decodeArgAppendNoPlus is almost identical to decodeArgAppend, but it doesn't +// substitute '+' with ' '. +// +// The function is copy-pasted from decodeArgAppend due to the performance +// reasons only. +func decodeArgAppendNoPlus(dst, src []byte) []byte { + if bytes.IndexByte(src, '%') < 0 { + // fast path: src doesn't contain encoded chars + return append(dst, src...) + } + + // slow path + for i := 0; i < len(src); i++ { + c := src[i] + if c == '%' { + if i+2 >= len(src) { + return append(dst, src[i:]...) 
+ } + x2 := hex2intTable[src[i+2]] + x1 := hex2intTable[src[i+1]] + if x1 == 16 || x2 == 16 { + dst = append(dst, '%') + } else { + dst = append(dst, x1<<4|x2) + i += 2 + } + } else { + dst = append(dst, c) + } } return dst } diff -Nru golang-github-valyala-fasthttp-20160617/args_test.go golang-github-valyala-fasthttp-1.31.0/args_test.go --- golang-github-valyala-fasthttp-20160617/args_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/args_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,35 +1,73 @@ package fasthttp import ( + "bytes" "fmt" + "net/url" "reflect" "strings" "testing" "time" + + "github.com/valyala/bytebufferpool" ) +func TestDecodeArgAppend(t *testing.T) { + t.Parallel() + + testDecodeArgAppend(t, "", "") + testDecodeArgAppend(t, "foobar", "foobar") + testDecodeArgAppend(t, "тест", "тест") + testDecodeArgAppend(t, "a%", "a%") + testDecodeArgAppend(t, "%a%21", "%a!") + testDecodeArgAppend(t, "ab%test", "ab%test") + testDecodeArgAppend(t, "d%тестF", "d%тестF") + testDecodeArgAppend(t, "a%\xffb%20c", "a%\xffb c") + testDecodeArgAppend(t, "foo%20bar", "foo bar") + testDecodeArgAppend(t, "f.o%2C1%3A2%2F4=%7E%60%21%40%23%24%25%5E%26*%28%29_-%3D%2B%5C%7C%2F%5B%5D%7B%7D%3B%3A%27%22%3C%3E%2C.%2F%3F", + "f.o,1:2/4=~`!@#$%^&*()_-=+\\|/[]{};:'\"<>,./?") +} + +func testDecodeArgAppend(t *testing.T, s, expectedResult string) { + result := decodeArgAppend(nil, []byte(s)) + if string(result) != expectedResult { + t.Fatalf("unexpected decodeArgAppend(%q)=%q; expecting %q", s, result, expectedResult) + } +} + func TestArgsAdd(t *testing.T) { + t.Parallel() + var a Args a.Add("foo", "bar") a.Add("foo", "baz") a.Add("foo", "1") a.Add("ba", "23") - if a.Len() != 4 { - t.Fatalf("unexpected number of elements: %d. Expecting 4", a.Len()) + a.Add("foo", "") + a.AddNoValue("foo") + if a.Len() != 6 { + t.Fatalf("unexpected number of elements: %d. Expecting 6", a.Len()) } s := a.String() - expectedS := "foo=bar&foo=baz&foo=1&ba=23" + expectedS := "foo=bar&foo=baz&foo=1&ba=23&foo=&foo" if s != expectedS { t.Fatalf("unexpected result: %q. Expecting %q", s, expectedS) } + a.Sort(bytes.Compare) + ss := a.String() + expectedSS := "ba=23&foo=&foo&foo=1&foo=bar&foo=baz" + if ss != expectedSS { + t.Fatalf("unexpected result: %q. Expecting %q", ss, expectedSS) + } + var a1 Args a1.Parse(s) - if a1.Len() != 4 { - t.Fatalf("unexpected number of elements: %d. Expecting 4", a.Len()) + if a1.Len() != 6 { + t.Fatalf("unexpected number of elements: %d. 
Expecting 6", a.Len()) } - var barFound, bazFound, oneFound, baFound bool + var barFound, bazFound, oneFound, emptyFound1, emptyFound2, baFound bool a1.VisitAll(func(k, v []byte) { switch string(k) { case "foo": @@ -40,6 +78,12 @@ bazFound = true case "1": oneFound = true + case "": + if emptyFound1 { + emptyFound2 = true + } else { + emptyFound1 = true + } default: t.Fatalf("unexpected value %q", v) } @@ -52,8 +96,8 @@ t.Fatalf("unexpected key found %q", k) } }) - if !barFound || !bazFound || !oneFound || !baFound { - t.Fatalf("something is missing: %v, %v, %v, %v", barFound, bazFound, oneFound, baFound) + if !barFound || !bazFound || !oneFound || !emptyFound1 || !emptyFound2 || !baFound { + t.Fatalf("something is missing: %v, %v, %v, %v, %v, %v", barFound, bazFound, oneFound, emptyFound1, emptyFound2, baFound) } } @@ -104,6 +148,8 @@ } func TestArgsPeekMulti(t *testing.T) { + t.Parallel() + var a Args a.Parse("foo=123&bar=121&foo=321&foo=&barz=sdf") @@ -130,9 +176,19 @@ } func TestArgsEscape(t *testing.T) { + t.Parallel() + testArgsEscape(t, "foo", "bar", "foo=bar") - testArgsEscape(t, "f.o,1:2/4", "~`!@#$%^&*()_-=+\\|/[]{};:'\"<>,./?", - "f.o%2C1%3A2%2F4=%7E%60%21%40%23%24%25%5E%26*%28%29_-%3D%2B%5C%7C%2F%5B%5D%7B%7D%3B%3A%27%22%3C%3E%2C.%2F%3F") + + // Test all characters + k := "f.o,1:2/4" + var v = make([]byte, 256) + for i := 0; i < 256; i++ { + v[i] = byte(i) + } + u := url.Values{} + u.Add(k, string(v)) + testArgsEscape(t, k, string(v), u.Encode()) } func testArgsEscape(t *testing.T, k, v, expectedS string) { @@ -144,13 +200,41 @@ } } +func TestPathEscape(t *testing.T) { + t.Parallel() + + testPathEscape(t, "/foo/bar") + testPathEscape(t, "") + testPathEscape(t, "/") + testPathEscape(t, "//") + testPathEscape(t, "*") // See https://github.com/golang/go/issues/11202 + + // Test all characters + var pathSegment = make([]byte, 256) + for i := 0; i < 256; i++ { + pathSegment[i] = byte(i) + } + testPathEscape(t, "/foo/"+string(pathSegment)) +} + +func testPathEscape(t *testing.T, s string) { + u := url.URL{Path: s} + expectedS := u.EscapedPath() + res := string(appendQuotedPath(nil, []byte(s))) + if res != expectedS { + t.Fatalf("unexpected args %q. Expecting %q.", res, expectedS) + } +} + func TestArgsWriteTo(t *testing.T) { + t.Parallel() + s := "foo=bar&baz=123&aaa=bbb" var a Args a.Parse(s) - var w ByteBuffer + var w bytebufferpool.ByteBuffer n, err := a.WriteTo(&w) if err != nil { t.Fatalf("unexpected error: %s", err) @@ -164,7 +248,34 @@ } } +func TestArgsGetBool(t *testing.T) { + t.Parallel() + + testArgsGetBool(t, "", false) + testArgsGetBool(t, "0", false) + testArgsGetBool(t, "n", false) + testArgsGetBool(t, "no", false) + testArgsGetBool(t, "1", true) + testArgsGetBool(t, "y", true) + testArgsGetBool(t, "yes", true) + + testArgsGetBool(t, "123", false) + testArgsGetBool(t, "foobar", false) +} + +func testArgsGetBool(t *testing.T, value string, expectedResult bool) { + var a Args + a.Parse("v=" + value) + + result := a.GetBool("v") + if result != expectedResult { + t.Fatalf("unexpected result %v. 
Expecting %v for value %q", result, expectedResult, value) + } +} + func TestArgsUint(t *testing.T) { + t.Parallel() + var a Args a.SetUint("foo", 123) a.SetUint("bar", 0) @@ -198,6 +309,8 @@ } func TestArgsCopyTo(t *testing.T) { + t.Parallel() + var a Args // empty args @@ -207,6 +320,7 @@ testCopyTo(t, &a) a.Set("xxx", "yyy") + a.AddNoValue("ba") testCopyTo(t, &a) a.Del("foo") @@ -222,6 +336,10 @@ var b Args a.CopyTo(&b) + if !reflect.DeepEqual(*a, b) { //nolint + t.Fatalf("ArgsCopyTo fail, a: \n%+v\nb: \n%+v\n", *a, b) //nolint + } + b.VisitAll(func(k, v []byte) { if _, ok := keys[string(k)]; !ok { t.Fatalf("unexpected key %q after copying from %q", k, a.String()) @@ -234,6 +352,8 @@ } func TestArgsVisitAll(t *testing.T) { + t.Parallel() + var a Args a.Set("foo", "bar") @@ -253,14 +373,18 @@ } func TestArgsStringCompose(t *testing.T) { + t.Parallel() + var a Args a.Set("foo", "bar") a.Set("aa", "bbb") a.Set("привет", "мир") + a.SetNoValue("bb") a.Set("", "xxxx") a.Set("cvx", "") + a.SetNoValue("novalue") - expectedS := "foo=bar&aa=bbb&%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=%D0%BC%D0%B8%D1%80&=xxxx&cvx" + expectedS := "foo=bar&aa=bbb&%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=%D0%BC%D0%B8%D1%80&bb&=xxxx&cvx=&novalue" s := a.String() if s != expectedS { t.Fatalf("Unexpected string %q. Exected %q", s, expectedS) @@ -268,6 +392,8 @@ } func TestArgsString(t *testing.T) { + t.Parallel() + var a Args testArgsString(t, &a, "") @@ -275,7 +401,7 @@ testArgsString(t, &a, "foo=bar") testArgsString(t, &a, "foo=bar&baz=sss") testArgsString(t, &a, "") - testArgsString(t, &a, "f%20o=x.x*-_8x%D0%BF%D1%80%D0%B8%D0%B2%D0%B5aaa&sdf=ss") + testArgsString(t, &a, "f+o=x.x%2A-_8x%D0%BF%D1%80%D0%B8%D0%B2%D0%B5aaa&sdf=ss") testArgsString(t, &a, "=asdfsdf") } @@ -288,6 +414,8 @@ } func TestArgsSetGetDel(t *testing.T) { + t.Parallel() + var a Args if len(a.Peek("foo")) > 0 { @@ -322,7 +450,7 @@ a.Parse("aaa=xxx&bb=aa") if string(a.Peek("foo0")) != "" { - t.Fatalf("Unepxected value %q", a.Peek("foo0")) + t.Fatalf("Unexpected value %q", a.Peek("foo0")) } if string(a.Peek("aaa")) != "xxx" { t.Fatalf("Unexpected value %q. Expected %q", a.Peek("aaa"), "xxx") @@ -353,6 +481,8 @@ } func TestArgsParse(t *testing.T) { + t.Parallel() + var a Args // empty args @@ -399,6 +529,8 @@ } func TestArgsHas(t *testing.T) { + t.Parallel() + var a Args // single arg @@ -451,3 +583,41 @@ } } } + +func TestArgsDeleteAll(t *testing.T) { + t.Parallel() + var a Args + a.Add("q1", "foo") + a.Add("q1", "bar") + a.Add("q1", "baz") + a.Add("q1", "quux") + a.Add("q2", "1234") + a.Del("q1") + if a.Len() != 1 || a.Has("q1") { + t.Fatalf("Expected q1 arg to be completely deleted. 
Current Args: %s", a.String()) + } +} + +func TestIssue932(t *testing.T) { + t.Parallel() + var a []argsKV + + a = setArg(a, "t1", "ok", argsHasValue) + a = setArg(a, "t2", "", argsHasValue) + a = setArg(a, "t1", "", argsHasValue) + a = setArgBytes(a, s2b("t3"), []byte{}, argsHasValue) + a = setArgBytes(a, s2b("t4"), nil, argsHasValue) + + if peekArgStr(a, "t1") == nil { + t.Error("nil not expected for t1") + } + if peekArgStr(a, "t2") == nil { + t.Error("nil not expected for t2") + } + if peekArgStr(a, "t3") == nil { + t.Error("nil not expected for t3") + } + if peekArgStr(a, "t4") != nil { + t.Error("nil expected for t4") + } +} diff -Nru golang-github-valyala-fasthttp-20160617/brotli.go golang-github-valyala-fasthttp-1.31.0/brotli.go --- golang-github-valyala-fasthttp-20160617/brotli.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/brotli.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,193 @@ +package fasthttp + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/andybalholm/brotli" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/stackless" +) + +// Supported compression levels. +const ( + CompressBrotliNoCompression = 0 + CompressBrotliBestSpeed = brotli.BestSpeed + CompressBrotliBestCompression = brotli.BestCompression + + // Choose a default brotli compression level comparable to + // CompressDefaultCompression (gzip 6) + // See: https://github.com/valyala/fasthttp/issues/798#issuecomment-626293806 + CompressBrotliDefaultCompression = 4 +) + +func acquireBrotliReader(r io.Reader) (*brotli.Reader, error) { + v := brotliReaderPool.Get() + if v == nil { + return brotli.NewReader(r), nil + } + zr := v.(*brotli.Reader) + if err := zr.Reset(r); err != nil { + return nil, err + } + return zr, nil +} + +func releaseBrotliReader(zr *brotli.Reader) { + brotliReaderPool.Put(zr) +} + +var brotliReaderPool sync.Pool + +func acquireStacklessBrotliWriter(w io.Writer, level int) stackless.Writer { + nLevel := normalizeBrotliCompressLevel(level) + p := stacklessBrotliWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealBrotliWriter(w, level) + }) + } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} + +func releaseStacklessBrotliWriter(sw stackless.Writer, level int) { + sw.Close() + nLevel := normalizeBrotliCompressLevel(level) + p := stacklessBrotliWriterPoolMap[nLevel] + p.Put(sw) +} + +func acquireRealBrotliWriter(w io.Writer, level int) *brotli.Writer { + nLevel := normalizeBrotliCompressLevel(level) + p := realBrotliWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + zw := brotli.NewWriterLevel(w, level) + return zw + } + zw := v.(*brotli.Writer) + zw.Reset(w) + return zw +} + +func releaseRealBrotliWriter(zw *brotli.Writer, level int) { + zw.Close() + nLevel := normalizeBrotliCompressLevel(level) + p := realBrotliWriterPoolMap[nLevel] + p.Put(zw) +} + +var ( + stacklessBrotliWriterPoolMap = newCompressWriterPoolMap() + realBrotliWriterPoolMap = newCompressWriterPoolMap() +) + +// AppendBrotliBytesLevel appends brotlied src to dst using the given +// compression level and returns the resulting dst. 
+// +// Supported compression levels are: +// +// * CompressBrotliNoCompression +// * CompressBrotliBestSpeed +// * CompressBrotliBestCompression +// * CompressBrotliDefaultCompression +func AppendBrotliBytesLevel(dst, src []byte, level int) []byte { + w := &byteSliceWriter{dst} + WriteBrotliLevel(w, src, level) //nolint:errcheck + return w.b +} + +// WriteBrotliLevel writes brotlied p to w using the given compression level +// and returns the number of compressed bytes written to w. +// +// Supported compression levels are: +// +// * CompressBrotliNoCompression +// * CompressBrotliBestSpeed +// * CompressBrotliBestCompression +// * CompressBrotliDefaultCompression +func WriteBrotliLevel(w io.Writer, p []byte, level int) (int, error) { + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + // These writers don't block, so we can just use stacklessWriteBrotli + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteBrotli(ctx) + return len(p), nil + default: + zw := acquireStacklessBrotliWriter(w, level) + n, err := zw.Write(p) + releaseStacklessBrotliWriter(zw, level) + return n, err + } +} + +var stacklessWriteBrotli = stackless.NewFunc(nonblockingWriteBrotli) + +func nonblockingWriteBrotli(ctxv interface{}) { + ctx := ctxv.(*compressCtx) + zw := acquireRealBrotliWriter(ctx.w, ctx.level) + + _, err := zw.Write(ctx.p) + if err != nil { + panic(fmt.Sprintf("BUG: brotli.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err)) + } + + releaseRealBrotliWriter(zw, ctx.level) +} + +// WriteBrotli writes brotlied p to w and returns the number of compressed +// bytes written to w. +func WriteBrotli(w io.Writer, p []byte) (int, error) { + return WriteBrotliLevel(w, p, CompressBrotliDefaultCompression) +} + +// AppendBrotliBytes appends brotlied src to dst and returns the resulting dst. +func AppendBrotliBytes(dst, src []byte) []byte { + return AppendBrotliBytesLevel(dst, src, CompressBrotliDefaultCompression) +} + +// WriteUnbrotli writes unbrotlied p to w and returns the number of uncompressed +// bytes written to w. +func WriteUnbrotli(w io.Writer, p []byte) (int, error) { + r := &byteSliceReader{p} + zr, err := acquireBrotliReader(r) + if err != nil { + return 0, err + } + n, err := copyZeroAlloc(w, zr) + releaseBrotliReader(zr) + nn := int(n) + if int64(nn) != n { + return 0, fmt.Errorf("too much data unbrotlied: %d", n) + } + return nn, err +} + +// AppendUnbrotliBytes appends unbrotlied src to dst and returns the resulting dst. +func AppendUnbrotliBytes(dst, src []byte) ([]byte, error) { + w := &byteSliceWriter{dst} + _, err := WriteUnbrotli(w, src) + return w.b, err +} + +// normalizes compression level into [0..11], so it could be used as an index +// in *PoolMap. 
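A round-trip sketch of the brotli helpers defined above; package main and the sample payload are illustrative assumptions, while AppendBrotliBytesLevel, CompressBrotliBestSpeed and AppendUnbrotliBytes come from this file:

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	src := []byte("hello, brotli compression in fasthttp")

	// Compress with the fastest supported level, then decompress and
	// verify the round trip.
	compressed := fasthttp.AppendBrotliBytesLevel(nil, src, fasthttp.CompressBrotliBestSpeed)
	plain, err := fasthttp.AppendUnbrotliBytes(nil, compressed)
	if err != nil {
		panic(err)
	}

	fmt.Println(string(plain) == string(src)) // true
}
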
+func normalizeBrotliCompressLevel(level int) int { + // -2 is the lowest compression level - CompressHuffmanOnly + // 9 is the highest compression level - CompressBestCompression + if level < 0 || level > 11 { + level = CompressBrotliDefaultCompression + } + return level +} diff -Nru golang-github-valyala-fasthttp-20160617/brotli_test.go golang-github-valyala-fasthttp-1.31.0/brotli_test.go --- golang-github-valyala-fasthttp-20160617/brotli_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/brotli_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,177 @@ +package fasthttp + +import ( + "bufio" + "bytes" + "fmt" + "io/ioutil" + "testing" +) + +func TestBrotliBytesSerial(t *testing.T) { + t.Parallel() + + if err := testBrotliBytes(); err != nil { + t.Fatal(err) + } +} + +func TestBrotliBytesConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testBrotliBytes); err != nil { + t.Fatal(err) + } +} + +func testBrotliBytes() error { + for _, s := range compressTestcases { + if err := testBrotliBytesSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testBrotliBytesSingleCase(s string) error { + prefix := []byte("foobar") + brotlipedS := AppendBrotliBytes(prefix, []byte(s)) + if !bytes.Equal(brotlipedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, brotlipedS[:len(prefix)], prefix) + } + + unbrotliedS, err := AppendUnbrotliBytes(prefix, brotlipedS[len(prefix):]) + if err != nil { + return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err) + } + if !bytes.Equal(unbrotliedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, unbrotliedS[:len(prefix)], prefix) + } + unbrotliedS = unbrotliedS[len(prefix):] + if string(unbrotliedS) != s { + return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", unbrotliedS, s) + } + return nil +} + +func TestBrotliCompressSerial(t *testing.T) { + t.Parallel() + + if err := testBrotliCompress(); err != nil { + t.Fatal(err) + } +} + +func TestBrotliCompressConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testBrotliCompress); err != nil { + t.Fatal(err) + } +} + +func testBrotliCompress() error { + for _, s := range compressTestcases { + if err := testBrotliCompressSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testBrotliCompressSingleCase(s string) error { + var buf bytes.Buffer + zw := acquireStacklessBrotliWriter(&buf, CompressDefaultCompression) + if _, err := zw.Write([]byte(s)); err != nil { + return fmt.Errorf("unexpected error: %s. s=%q", err, s) + } + releaseStacklessBrotliWriter(zw, CompressDefaultCompression) + + zr, err := acquireBrotliReader(&buf) + if err != nil { + return fmt.Errorf("unexpected error: %s. s=%q", err, s) + } + body, err := ioutil.ReadAll(zr) + if err != nil { + return fmt.Errorf("unexpected error: %s. s=%q", err, s) + } + if string(body) != s { + return fmt.Errorf("unexpected string after decompression: %q. 
Expecting %q", body, s) + } + releaseBrotliReader(zr) + return nil +} + +func TestCompressHandlerBrotliLevel(t *testing.T) { + t.Parallel() + + expectedBody := string(createFixedBody(2e4)) + h := CompressHandlerBrotliLevel(func(ctx *RequestCtx) { + ctx.Write([]byte(expectedBody)) //nolint:errcheck + }, CompressBrotliDefaultCompression, CompressDefaultCompression) + + var ctx RequestCtx + var resp Response + + // verify uncompressed response + h(&ctx) + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce := resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") + } + body := resp.Body() + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify gzip-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify brotli-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc, br") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "br" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "br") + } + body, err = resp.BodyUnbrotli() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/bytebuffer_example_test.go golang-github-valyala-fasthttp-1.31.0/bytebuffer_example_test.go --- golang-github-valyala-fasthttp-20160617/bytebuffer_example_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytebuffer_example_test.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,29 +0,0 @@ -package fasthttp_test - -import ( - "fmt" - - "github.com/valyala/fasthttp" -) - -func ExampleByteBuffer() { - // This request handler sets 'Your-IP' response header - // to 'Your IP is '. It uses ByteBuffer for constructing response - // header value with zero memory allocations. - yourIPRequestHandler := func(ctx *fasthttp.RequestCtx) { - b := fasthttp.AcquireByteBuffer() - b.B = append(b.B, "Your IP is <"...) - b.B = fasthttp.AppendIPv4(b.B, ctx.RemoteIP()) - b.B = append(b.B, ">"...) - ctx.Response.Header.SetBytesV("Your-IP", b.B) - - fmt.Fprintf(ctx, "Check response headers - they must contain 'Your-IP: %s'", b.B) - - // It is safe to release byte buffer now, since it is - // no longer used. - fasthttp.ReleaseByteBuffer(b) - } - - // Start fasthttp server returning your ip in response headers. 
- fasthttp.ListenAndServe(":8080", yourIPRequestHandler) -} diff -Nru golang-github-valyala-fasthttp-20160617/bytebuffer.go golang-github-valyala-fasthttp-1.31.0/bytebuffer.go --- golang-github-valyala-fasthttp-20160617/bytebuffer.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytebuffer.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,88 +0,0 @@ -package fasthttp - -import ( - "sync" -) - -const ( - defaultByteBufferSize = 128 -) - -// ByteBuffer provides byte buffer, which can be used with fasthttp API -// in order to minimize memory allocations. -// -// ByteBuffer may be used with functions appending data to the given []byte -// slice. See example code for details. -// -// Use AcquireByteBuffer for obtaining an empty byte buffer. -type ByteBuffer struct { - - // B is a byte buffer to use in append-like workloads. - // See example code for details. - B []byte -} - -// Write implements io.Writer - it appends p to ByteBuffer.B -func (b *ByteBuffer) Write(p []byte) (int, error) { - b.B = append(b.B, p...) - return len(p), nil -} - -// WriteString appends s to ByteBuffer.B -func (b *ByteBuffer) WriteString(s string) (int, error) { - b.B = append(b.B, s...) - return len(s), nil -} - -// Set sets ByteBuffer.B to p -func (b *ByteBuffer) Set(p []byte) { - b.B = append(b.B[:0], p...) -} - -// SetString sets ByteBuffer.B to s -func (b *ByteBuffer) SetString(s string) { - b.B = append(b.B[:0], s...) -} - -// Reset makes ByteBuffer.B empty. -func (b *ByteBuffer) Reset() { - b.B = b.B[:0] -} - -// AcquireByteBuffer returns an empty byte buffer from the pool. -// -// Acquired byte buffer may be returned to the pool via ReleaseByteBuffer call. -// This reduces the number of memory allocations required for byte buffer -// management. -func AcquireByteBuffer() *ByteBuffer { - return defaultByteBufferPool.Acquire() -} - -// ReleaseByteBuffer returns byte buffer to the pool. -// -// ByteBuffer.B mustn't be touched after returning it to the pool. -// Otherwise data races occur. -func ReleaseByteBuffer(b *ByteBuffer) { - defaultByteBufferPool.Release(b) -} - -type byteBufferPool struct { - pool sync.Pool -} - -var defaultByteBufferPool byteBufferPool - -func (p *byteBufferPool) Acquire() *ByteBuffer { - v := p.pool.Get() - if v == nil { - return &ByteBuffer{ - B: make([]byte, 0, defaultByteBufferSize), - } - } - return v.(*ByteBuffer) -} - -func (p *byteBufferPool) Release(b *ByteBuffer) { - b.B = b.B[:0] - p.pool.Put(b) -} diff -Nru golang-github-valyala-fasthttp-20160617/bytebuffer_test.go golang-github-valyala-fasthttp-1.31.0/bytebuffer_test.go --- golang-github-valyala-fasthttp-20160617/bytebuffer_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytebuffer_test.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,43 +0,0 @@ -package fasthttp - -import ( - "fmt" - "testing" - "time" -) - -func TestByteBufferAcquireReleaseSerial(t *testing.T) { - testByteBufferAcquireRelease(t) -} - -func TestByteBufferAcquireReleaseConcurrent(t *testing.T) { - concurrency := 10 - ch := make(chan struct{}, concurrency) - for i := 0; i < concurrency; i++ { - go func() { - testByteBufferAcquireRelease(t) - ch <- struct{}{} - }() - } - - for i := 0; i < concurrency; i++ { - select { - case <-ch: - case <-time.After(time.Second): - t.Fatalf("timeout!") - } - } -} - -func testByteBufferAcquireRelease(t *testing.T) { - for i := 0; i < 10; i++ { - b := AcquireByteBuffer() - b.B = append(b.B, "num "...) 
- b.B = AppendUint(b.B, i) - expectedS := fmt.Sprintf("num %d", i) - if string(b.B) != expectedS { - t.Fatalf("unexpected result: %q. Expecting %q", b.B, expectedS) - } - ReleaseByteBuffer(b) - } -} diff -Nru golang-github-valyala-fasthttp-20160617/bytebuffer_timing_test.go golang-github-valyala-fasthttp-1.31.0/bytebuffer_timing_test.go --- golang-github-valyala-fasthttp-20160617/bytebuffer_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytebuffer_timing_test.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,32 +0,0 @@ -package fasthttp - -import ( - "bytes" - "testing" -) - -func BenchmarkByteBufferWrite(b *testing.B) { - s := []byte("foobarbaz") - b.RunParallel(func(pb *testing.PB) { - var buf ByteBuffer - for pb.Next() { - for i := 0; i < 100; i++ { - buf.Write(s) - } - buf.Reset() - } - }) -} - -func BenchmarkBytesBufferWrite(b *testing.B) { - s := []byte("foobarbaz") - b.RunParallel(func(pb *testing.PB) { - var buf bytes.Buffer - for pb.Next() { - for i := 0; i < 100; i++ { - buf.Write(s) - } - buf.Reset() - } - }) -} diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_32.go golang-github-valyala-fasthttp-1.31.0/bytesconv_32.go --- golang-github-valyala-fasthttp-20160617/bytesconv_32.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_32.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,8 +1,8 @@ -// +build !amd64,!arm64,!ppc64 +//go:build !amd64 && !arm64 && !ppc64 && !ppc64le +// +build !amd64,!arm64,!ppc64,!ppc64le package fasthttp const ( - maxIntChars = 9 maxHexIntChars = 7 ) diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_32_test.go golang-github-valyala-fasthttp-1.31.0/bytesconv_32_test.go --- golang-github-valyala-fasthttp-20160617/bytesconv_32_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_32_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,4 +1,5 @@ -// +build !amd64,!arm64,!ppc64 +//go:build !amd64 && !arm64 && !ppc64 && !ppc64le +// +build !amd64,!arm64,!ppc64,!ppc64le package fasthttp @@ -7,6 +8,8 @@ ) func TestWriteHexInt(t *testing.T) { + t.Parallel() + testWriteHexInt(t, 0, "0") testWriteHexInt(t, 1, "1") testWriteHexInt(t, 0x123, "123") @@ -14,6 +17,8 @@ } func TestAppendUint(t *testing.T) { + t.Parallel() + testAppendUint(t, 0) testAppendUint(t, 123) testAppendUint(t, 0x7fffffff) @@ -24,6 +29,8 @@ } func TestReadHexIntSuccess(t *testing.T) { + t.Parallel() + testReadHexIntSuccess(t, "0", 0) testReadHexIntSuccess(t, "fF", 0xff) testReadHexIntSuccess(t, "00abc", 0xabc) @@ -32,8 +39,22 @@ testReadHexIntSuccess(t, "1234ZZZ", 0x1234) } +func TestParseUintError32(t *testing.T) { + t.Parallel() + + // Overflow by last digit: 2 ** 32 / 2 * 10 ** n + testParseUintError(t, "2147483648") + testParseUintError(t, "21474836480") + testParseUintError(t, "214748364800") +} + func TestParseUintSuccess(t *testing.T) { + t.Parallel() + testParseUintSuccess(t, "0", 0) testParseUintSuccess(t, "123", 123) testParseUintSuccess(t, "123456789", 123456789) + + // Max supported value: 2 ** 32 / 2 - 1 + testParseUintSuccess(t, "2147483647", 2147483647) } diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_64.go golang-github-valyala-fasthttp-1.31.0/bytesconv_64.go --- golang-github-valyala-fasthttp-20160617/bytesconv_64.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_64.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,8 +1,8 @@ -// +build amd64 arm64 ppc64 +//go:build amd64 || arm64 || ppc64 || ppc64le +// 
+build amd64 arm64 ppc64 ppc64le package fasthttp const ( - maxIntChars = 18 maxHexIntChars = 15 ) diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_64_test.go golang-github-valyala-fasthttp-1.31.0/bytesconv_64_test.go --- golang-github-valyala-fasthttp-20160617/bytesconv_64_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_64_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,4 +1,5 @@ -// +build amd64 arm64 ppc64 +//go:build amd64 || arm64 || ppc64 || ppc64le +// +build amd64 arm64 ppc64 ppc64le package fasthttp @@ -7,6 +8,8 @@ ) func TestWriteHexInt(t *testing.T) { + t.Parallel() + testWriteHexInt(t, 0, "0") testWriteHexInt(t, 1, "1") testWriteHexInt(t, 0x123, "123") @@ -14,6 +17,8 @@ } func TestAppendUint(t *testing.T) { + t.Parallel() + testAppendUint(t, 0) testAppendUint(t, 123) testAppendUint(t, 0x7fffffffffffffff) @@ -24,6 +29,8 @@ } func TestReadHexIntSuccess(t *testing.T) { + t.Parallel() + testReadHexIntSuccess(t, "0", 0) testReadHexIntSuccess(t, "fF", 0xff) testReadHexIntSuccess(t, "00abc", 0xabc) @@ -33,9 +40,23 @@ testReadHexIntSuccess(t, "7ffffffffffffff", 0x7ffffffffffffff) } +func TestParseUintError64(t *testing.T) { + t.Parallel() + + // Overflow by last digit: 2 ** 64 / 2 * 10 ** n + testParseUintError(t, "9223372036854775808") + testParseUintError(t, "92233720368547758080") + testParseUintError(t, "922337203685477580800") +} + func TestParseUintSuccess(t *testing.T) { + t.Parallel() + testParseUintSuccess(t, "0", 0) testParseUintSuccess(t, "123", 123) testParseUintSuccess(t, "1234567890", 1234567890) testParseUintSuccess(t, "123456789012345678", 123456789012345678) + + // Max supported value: 2 ** 64 / 2 - 1 + testParseUintSuccess(t, "9223372036854775807", 9223372036854775807) } diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv.go golang-github-valyala-fasthttp-1.31.0/bytesconv.go --- golang-github-valyala-fasthttp-20160617/bytesconv.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,3 +1,5 @@ +//go:generate go run bytesconv_table_gen.go + package fasthttp import ( @@ -9,6 +11,7 @@ "math" "net" "reflect" + "strings" "sync" "time" "unsafe" @@ -16,6 +19,16 @@ // AppendHTMLEscape appends html-escaped s to dst and returns the extended dst. func AppendHTMLEscape(dst []byte, s string) []byte { + if strings.IndexByte(s, '<') < 0 && + strings.IndexByte(s, '>') < 0 && + strings.IndexByte(s, '"') < 0 && + strings.IndexByte(s, '\'') < 0 { + + // fast path - nothing to escape + return append(dst, s...) + } + + // slow path var prev int var sub string for i, n := 0, len(s); i < n; i++ { @@ -153,7 +166,7 @@ var ( errEmptyInt = errors.New("empty integer") errUnexpectedFirstChar = errors.New("unexpected first char found. Expecting 0-9") - errUnexpectedTrailingChar = errors.New("unexpected traling char found. Expecting 0-9") + errUnexpectedTrailingChar = errors.New("unexpected trailing char found. Expecting 0-9") errTooLongInt = errors.New("too long int") ) @@ -172,10 +185,12 @@ } return v, i, nil } - if i >= maxIntChars { + vNew := 10*v + int(k) + // Test for overflow. 
+ if vNew < v { return -1, i, errTooLongInt } - v = 10*v + int(k) + v = vNew } return v, n, nil } @@ -254,12 +269,14 @@ } return -1, err } - k = hexbyte2int(c) - if k < 0 { + k = int(hex2intTable[c]) + if k == 16 { if i == 0 { return -1, errEmptyHexNum } - r.UnreadByte() + if err := r.UnreadByte(); err != nil { + return -1, err + } return n, nil } if i >= maxHexIntChars { @@ -284,7 +301,7 @@ buf := v.([]byte) i := len(buf) - 1 for { - buf[i] = int2hexbyte(n & 0xf) + buf[i] = lowerhex[n&0xf] n >>= 4 if n == 0 { break @@ -296,59 +313,15 @@ return err } -func int2hexbyte(n int) byte { - if n < 10 { - return '0' + byte(n) - } - return 'a' + byte(n) - 10 -} - -func hexCharUpper(c byte) byte { - if c < 10 { - return '0' + c - } - return c - 10 + 'A' -} - -var hex2intTable = func() []byte { - b := make([]byte, 255) - for i := byte(0); i < 255; i++ { - c := byte(0) - if i >= '0' && i <= '9' { - c = 1 + i - '0' - } else if i >= 'a' && i <= 'f' { - c = 1 + i - 'a' + 10 - } else if i >= 'A' && i <= 'F' { - c = 1 + i - 'A' + 10 - } - b[i] = c - } - return b -}() - -func hexbyte2int(c byte) int { - return int(hex2intTable[c]) - 1 -} - -const toLower = 'a' - 'A' - -func uppercaseByte(p *byte) { - c := *p - if c >= 'a' && c <= 'z' { - *p = c - toLower - } -} - -func lowercaseByte(p *byte) { - c := *p - if c >= 'A' && c <= 'Z' { - *p = c + toLower - } -} +const ( + upperhex = "0123456789ABCDEF" + lowerhex = "0123456789abcdef" +) func lowercaseBytes(b []byte) { - for i, n := 0, len(b); i < n; i++ { - lowercaseByte(&b[i]) + for i := 0; i < len(b); i++ { + p := &b[i] + *p = toLowerTable[*p] } } @@ -358,6 +331,7 @@ // Note it may break if string and/or slice header will change // in the future go versions. func b2s(b []byte) string { + /* #nosec G103 */ return *(*string)(unsafe.Pointer(&b)) } @@ -365,58 +339,51 @@ // // Note it may break if string and/or slice header will change // in the future go versions. -func s2b(s string) []byte { +func s2b(s string) (b []byte) { + /* #nosec G103 */ + bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + /* #nosec G103 */ sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) - bh := reflect.SliceHeader{ - Data: sh.Data, - Len: sh.Len, - Cap: sh.Len, - } - return *(*[]byte)(unsafe.Pointer(&bh)) + bh.Data = sh.Data + bh.Cap = sh.Len + bh.Len = sh.Len + return b +} + +// AppendUnquotedArg appends url-decoded src to dst and returns appended dst. +// +// dst may point to src. In this case src will be overwritten. +func AppendUnquotedArg(dst, src []byte) []byte { + return decodeArgAppend(dst, src) } // AppendQuotedArg appends url-encoded src to dst and returns appended dst. func AppendQuotedArg(dst, src []byte) []byte { for _, c := range src { - // See http://www.w3.org/TR/html5/forms.html#form-submission-algorithm - if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || - c == '*' || c == '-' || c == '.' || c == '_' { + switch { + case c == ' ': + dst = append(dst, '+') + case quotedArgShouldEscapeTable[int(c)] != 0: + dst = append(dst, '%', upperhex[c>>4], upperhex[c&0xf]) + default: dst = append(dst, c) - } else { - dst = append(dst, '%', hexCharUpper(c>>4), hexCharUpper(c&15)) } } return dst } func appendQuotedPath(dst, src []byte) []byte { + // Fix issue in https://github.com/golang/go/issues/11202 + if len(src) == 1 && src[0] == '*' { + return append(dst, '*') + } + for _, c := range src { - if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || - c == '/' || c == '.' 
|| c == ',' || c == '=' || c == ':' || c == '&' || c == '~' || c == '-' || c == '_' { - dst = append(dst, c) + if quotedPathShouldEscapeTable[int(c)] != 0 { + dst = append(dst, '%', upperhex[c>>4], upperhex[c&15]) } else { - dst = append(dst, '%', hexCharUpper(c>>4), hexCharUpper(c&15)) + dst = append(dst, c) } } return dst } - -// EqualBytesStr returns true if string(b) == s. -// -// This function has no performance benefits comparing to string(b) == s. -// It is left here for backwards compatibility only. -// -// This function is deperecated and may be deleted soon. -func EqualBytesStr(b []byte, s string) bool { - return string(b) == s -} - -// AppendBytesStr appends src to dst and returns the extended dst. -// -// This function has no performance benefits comparing to append(dst, src...). -// It is left here for backwards compatibility only. -// -// This function is deprecated and may be deleted soon. -func AppendBytesStr(dst []byte, src string) []byte { - return append(dst, src...) -} diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_table_gen.go golang-github-valyala-fasthttp-1.31.0/bytesconv_table_gen.go --- golang-github-valyala-fasthttp-20160617/bytesconv_table_gen.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_table_gen.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,120 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "bytes" + "fmt" + "io/ioutil" + "log" +) + +const ( + toLower = 'a' - 'A' +) + +func main() { + hex2intTable := func() [256]byte { + var b [256]byte + for i := 0; i < 256; i++ { + c := byte(16) + if i >= '0' && i <= '9' { + c = byte(i) - '0' + } else if i >= 'a' && i <= 'f' { + c = byte(i) - 'a' + 10 + } else if i >= 'A' && i <= 'F' { + c = byte(i) - 'A' + 10 + } + b[i] = c + } + return b + }() + + toLowerTable := func() [256]byte { + var a [256]byte + for i := 0; i < 256; i++ { + c := byte(i) + if c >= 'A' && c <= 'Z' { + c += toLower + } + a[i] = c + } + return a + }() + + toUpperTable := func() [256]byte { + var a [256]byte + for i := 0; i < 256; i++ { + c := byte(i) + if c >= 'a' && c <= 'z' { + c -= toLower + } + a[i] = c + } + return a + }() + + quotedArgShouldEscapeTable := func() [256]byte { + // According to RFC 3986 §2.3 + var a [256]byte + for i := 0; i < 256; i++ { + a[i] = 1 + } + + // ALPHA + for i := int('a'); i <= int('z'); i++ { + a[i] = 0 + } + for i := int('A'); i <= int('Z'); i++ { + a[i] = 0 + } + + // DIGIT + for i := int('0'); i <= int('9'); i++ { + a[i] = 0 + } + + // Unreserved characters + for _, v := range `-_.~` { + a[v] = 0 + } + + return a + }() + + quotedPathShouldEscapeTable := func() [256]byte { + // The implementation here equal to net/url shouldEscape(s, encodePath) + // + // The RFC allows : @ & = + $ but saves / ; , for assigning + // meaning to individual path segments. This package + // only manipulates the path as a whole, so we allow those + // last three as well. That leaves only ? to escape. 
+ var a = quotedArgShouldEscapeTable + + for _, v := range `$&+,/:;=@` { + a[v] = 0 + } + + return a + }() + + w := new(bytes.Buffer) + w.WriteString(pre) + fmt.Fprintf(w, "const hex2intTable = %q\n", hex2intTable) + fmt.Fprintf(w, "const toLowerTable = %q\n", toLowerTable) + fmt.Fprintf(w, "const toUpperTable = %q\n", toUpperTable) + fmt.Fprintf(w, "const quotedArgShouldEscapeTable = %q\n", quotedArgShouldEscapeTable) + fmt.Fprintf(w, "const quotedPathShouldEscapeTable = %q\n", quotedPathShouldEscapeTable) + + if err := ioutil.WriteFile("bytesconv_table.go", w.Bytes(), 0660); err != nil { + log.Fatal(err) + } +} + +const pre = `package fasthttp + +// Code generated by go run bytesconv_table_gen.go; DO NOT EDIT. +// See bytesconv_table_gen.go for more information about these tables. + +` diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_table.go golang-github-valyala-fasthttp-1.31.0/bytesconv_table.go --- golang-github-valyala-fasthttp-20160617/bytesconv_table.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_table.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,10 @@ +package fasthttp + +// Code generated by go run bytesconv_table_gen.go; DO NOT EDIT. +// See bytesconv_table_gen.go for more information about these tables. + +const hex2intTable = "\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x00\x01\x02\x03\x04\x05\x06\a\b\t\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\n\v\f\r\x0e\x0f\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10\x10" +const toLowerTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f !\"#$%&'()*+,-./0123456789:;<=>?@abcdefghijklmnopqrstuvwxyz[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" +const toUpperTable = "\x00\x01\x02\x03\x04\x05\x06\a\b\t\n\v\f\r\x0e\x0f\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f 
!\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`ABCDEFGHIJKLMNOPQRSTUVWXYZ{|}~\u007f\x80\x81\x82\x83\x84\x85\x86\x87\x88\x89\x8a\x8b\x8c\x8d\x8e\x8f\x90\x91\x92\x93\x94\x95\x96\x97\x98\x99\x9a\x9b\x9c\x9d\x9e\x9f\xa0\xa1\xa2\xa3\xa4\xa5\xa6\xa7\xa8\xa9\xaa\xab\xac\xad\xae\xaf\xb0\xb1\xb2\xb3\xb4\xb5\xb6\xb7\xb8\xb9\xba\xbb\xbc\xbd\xbe\xbf\xc0\xc1\xc2\xc3\xc4\xc5\xc6\xc7\xc8\xc9\xca\xcb\xcc\xcd\xce\xcf\xd0\xd1\xd2\xd3\xd4\xd5\xd6\xd7\xd8\xd9\xda\xdb\xdc\xdd\xde\xdf\xe0\xe1\xe2\xe3\xe4\xe5\xe6\xe7\xe8\xe9\xea\xeb\xec\xed\xee\xef\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7\xf8\xf9\xfa\xfb\xfc\xfd\xfe\xff" +const quotedArgShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" +const quotedPathShouldEscapeTable = "\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x00\x01\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x01\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01\x01\x00\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_test.go golang-github-valyala-fasthttp-1.31.0/bytesconv_test.go --- golang-github-valyala-fasthttp-20160617/bytesconv_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -7,9 +7,13 @@ "net" "testing" "time" + + "github.com/valyala/bytebufferpool" ) func TestAppendHTMLEscape(t *testing.T) { + t.Parallel() + testAppendHTMLEscape(t, "", "") testAppendHTMLEscape(t, "<", "<") testAppendHTMLEscape(t, "a", "a") @@ -25,6 +29,8 @@ } func TestParseIPv4(t *testing.T) { + t.Parallel() + testParseIPv4(t, "0.0.0.0", true) testParseIPv4(t, "255.255.255.255", true) testParseIPv4(t, "123.45.67.89", true) @@ -56,6 +62,8 @@ } func TestAppendIPv4(t *testing.T) { + 
t.Parallel() + testAppendIPv4(t, "0.0.0.0", true) testAppendIPv4(t, "127.0.0.1", true) testAppendIPv4(t, "8.8.8.8", true) @@ -73,7 +81,7 @@ s := string(AppendIPv4(nil, ip)) if isValid { if s != ipStr { - t.Fatalf("unepxected ip %q. Expecting %q", s, ipStr) + t.Fatalf("unexpected ip %q. Expecting %q", s, ipStr) } } else { ipStr = "non-v4 ip passed to AppendIPv4" @@ -92,7 +100,7 @@ } func testWriteHexInt(t *testing.T, n int, expectedS string) { - var w ByteBuffer + var w bytebufferpool.ByteBuffer bw := bufio.NewWriter(&w) if err := writeHexInt(bw, n); err != nil { t.Fatalf("unexpected error when writing hex %x: %s", n, err) @@ -107,6 +115,8 @@ } func TestReadHexIntError(t *testing.T) { + t.Parallel() + testReadHexIntError(t, "") testReadHexIntError(t, "ZZZ") testReadHexIntError(t, "-123") @@ -138,6 +148,8 @@ } func TestAppendHTTPDate(t *testing.T) { + t.Parallel() + d := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) s := string(AppendHTTPDate(nil, d)) expectedS := "Tue, 10 Nov 2009 23:00:00 GMT" @@ -157,6 +169,8 @@ } func TestParseUintError(t *testing.T) { + t.Parallel() + // empty string testParseUintError(t, "") @@ -174,20 +188,25 @@ // too big num testParseUintError(t, "12345678901234567890") + testParseUintError(t, "1234567890123456789012") } func TestParseUfloatSuccess(t *testing.T) { + t.Parallel() + testParseUfloatSuccess(t, "0", 0) testParseUfloatSuccess(t, "1.", 1.) testParseUfloatSuccess(t, ".1", 0.1) testParseUfloatSuccess(t, "123.456", 123.456) testParseUfloatSuccess(t, "123", 123) testParseUfloatSuccess(t, "1234e2", 1234e2) - testParseUfloatSuccess(t, "1234E-5", 1234E-5) + testParseUfloatSuccess(t, "1234E-5", 1234e-5) testParseUfloatSuccess(t, "1.234e+3", 1.234e+3) } func TestParseUfloatError(t *testing.T) { + t.Parallel() + // empty num testParseUfloatError(t, "") @@ -257,3 +276,46 @@ t.Fatalf("Unexpected value %d. Expected %d. 
num=%q", n, expectedN, s) } } + +func TestAppendUnquotedArg(t *testing.T) { + t.Parallel() + + testAppendUnquotedArg(t, "", "") + testAppendUnquotedArg(t, "abc", "abc") + testAppendUnquotedArg(t, "тест.abc", "тест.abc") + testAppendUnquotedArg(t, "%D1%82%D0%B5%D1%81%D1%82%20%=&;:", "тест %=&;:") +} + +func testAppendUnquotedArg(t *testing.T, s, expectedS string) { + // test appending to nil + result := AppendUnquotedArg(nil, []byte(s)) + if string(result) != expectedS { + t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS) + } + + // test appending to prefix + prefix := "prefix" + dst := []byte(prefix) + dst = AppendUnquotedArg(dst, []byte(s)) + if !bytes.HasPrefix(dst, []byte(prefix)) { + t.Fatalf("Unexpected prefix for AppendUnquotedArg(%q)=%q, want %q", s, dst, prefix) + } + result = dst[len(prefix):] + if string(result) != expectedS { + t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS) + } + + // test in-place appending + result = []byte(s) + result = AppendUnquotedArg(result[:0], result) + if string(result) != expectedS { + t.Fatalf("Unexpected AppendUnquotedArg(%q)=%q, want %q", s, result, expectedS) + } + + // verify AppendQuotedArg <-> AppendUnquotedArg conversion + quotedS := AppendQuotedArg(nil, []byte(s)) + unquotedS := AppendUnquotedArg(nil, quotedS) + if s != string(unquotedS) { + t.Fatalf("Unexpected AppendUnquotedArg(AppendQuotedArg(%q))=%q, want %q", s, unquotedS, s) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/bytesconv_timing_test.go golang-github-valyala-fasthttp-1.31.0/bytesconv_timing_test.go --- golang-github-valyala-fasthttp-20160617/bytesconv_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/bytesconv_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -5,6 +5,8 @@ "html" "net" "testing" + + "github.com/valyala/bytebufferpool" ) func BenchmarkAppendHTMLEscape(b *testing.B) { @@ -63,37 +65,13 @@ }) } -func BenchmarkInt2HexByte(b *testing.B) { - buf := []int{1, 0xf, 2, 0xd, 3, 0xe, 4, 0xa, 5, 0xb, 6, 0xc, 7, 0xf, 0, 0xf, 6, 0xd, 9, 8, 4, 0x5} - b.RunParallel(func(pb *testing.PB) { - var n int - for pb.Next() { - for _, n = range buf { - int2hexbyte(n) - } - } - }) -} - -func BenchmarkHexByte2Int(b *testing.B) { - buf := []byte("0A1B2c3d4E5F6C7a8D9ab7cd03ef") - b.RunParallel(func(pb *testing.PB) { - var c byte - for pb.Next() { - for _, c = range buf { - hexbyte2int(c) - } - } - }) -} - func BenchmarkWriteHexInt(b *testing.B) { b.RunParallel(func(pb *testing.PB) { - var w ByteBuffer + var w bytebufferpool.ByteBuffer bw := bufio.NewWriter(&w) i := 0 for pb.Next() { - writeHexInt(bw, i) + writeHexInt(bw, i) //nolint:errcheck i++ if i > 0x7fffffff { i = 0 @@ -165,3 +143,23 @@ } }) } + +func BenchmarkAppendUnquotedArgFastPath(b *testing.B) { + src := []byte("foobarbaz no quoted chars fdskjsdf jklsdfdfskljd;aflskjdsaf fdsklj fsdkj fsdl kfjsdlk jfsdklj fsdfsdf sdfkflsd") + b.RunParallel(func(pb *testing.PB) { + var dst []byte + for pb.Next() { + dst = AppendUnquotedArg(dst[:0], src) + } + }) +} + +func BenchmarkAppendUnquotedArgSlowPath(b *testing.B) { + src := []byte("D0%B4%20%D0%B0%D0%B2%D0%BB%D0%B4%D1%84%D1%8B%D0%B0%D0%BE%20%D1%84%D0%B2%D0%B6%D0%BB%D0%B4%D1%8B%20%D0%B0%D0%BE") + b.RunParallel(func(pb *testing.PB) { + var dst []byte + for pb.Next() { + dst = AppendUnquotedArg(dst[:0], src) + } + }) +} diff -Nru golang-github-valyala-fasthttp-20160617/client.go golang-github-valyala-fasthttp-1.31.0/client.go --- golang-github-valyala-fasthttp-20160617/client.go 
2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/client.go 2021-10-09 18:39:05.000000000 +0000 @@ -8,6 +8,7 @@ "fmt" "io" "net" + "strconv" "strings" "sync" "sync/atomic" @@ -19,13 +20,15 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // -// Response is ignored if resp is nil. -// // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// The function doesn't follow redirects. Use Get* for following redirects. +// +// Response is ignored if resp is nil. +// // ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections // to the requested host are busy. // @@ -46,13 +49,23 @@ // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // +// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections +// to the requested host are busy. +// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. +// +// Warning: DoTimeout does not terminate the request itself. The request will +// continue in the background and the response will be discarded. +// If requests take too long and the connection pool gets filled up please +// try using a Client and setting a ReadTimeout. func DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return defaultClient.DoTimeout(req, resp, timeout) } @@ -68,27 +81,62 @@ // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // +// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections +// to the requested host are busy. +// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func DoDeadline(req *Request, resp *Response, deadline time.Time) error { return defaultClient.DoDeadline(req, resp, deadline) } -// Get appends url contents to dst and returns it as body. +// DoRedirects performs the given http request and fills the given http response, +// following up to maxRedirectsCount redirects. When the redirect count exceeds +// maxRedirectsCount, ErrTooManyRedirects is returned. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// +// Client determines the server to be requested in the following order: +// +// - from RequestURI if it contains full url with scheme and host; +// - from Host header otherwise. +// +// Response is ignored if resp is nil. +// +// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections +// to the requested host are busy. +// +// It is recommended obtaining req and resp via AcquireRequest +// and AcquireResponse in performance-critical code. 
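Note on the package-level DoRedirects declared just below: it pairs with the AcquireRequest/AcquireResponse pools recommended above. A minimal usage sketch, assuming the fasthttp import; the URL and the redirect limit of 16 are illustrative placeholders, not taken from the patch:

// Follow up to 16 redirects with the default client; ErrTooManyRedirects
// is returned once that budget is exceeded.
func fetchFollowingRedirects() error {
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.com/start") // placeholder URL
	return fasthttp.DoRedirects(req, resp, 16)
}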
+func DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, &defaultClient) + return err +} + +// Get returns the status code and body of url. +// +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. func Get(dst []byte, url string) (statusCode int, body []byte, err error) { return defaultClient.Get(dst, url) } -// GetTimeout appends url contents to dst and returns it as body. +// GetTimeout returns the status code and body of url. // -// New body buffer is allocated if dst is nil. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. +// +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. @@ -96,9 +144,12 @@ return defaultClient.GetTimeout(dst, url, timeout) } -// GetDeadline appends url contents to dst and returns it as body. +// GetDeadline returns the status code and body of url. +// +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. @@ -108,9 +159,10 @@ // Post sends POST request to the given url with the given POST arguments. // -// Response body is appended to dst, which is returned as body. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -124,14 +176,20 @@ // Copying Client by value is prohibited. Create new instance instead. // // It is safe calling Client methods from concurrently running goroutines. +// +// The fields of a Client should not be changed while it is in use. type Client struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Client name. Used in User-Agent request header. // // Default client name is used if not set. Name string + // NoDefaultUserAgentHeader when set to true, causes the default + // User-Agent header to be excluded from the Request. + NoDefaultUserAgentHeader bool + // Callback for establishing new connections to hosts. // // Default Dial is used if not set. @@ -162,6 +220,16 @@ // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration + // Keep-alive connections are closed after this duration. + // + // By default connection duration is unlimited. + MaxConnDuration time.Duration + + // Maximum number of attempts for idempotent calls + // + // DefaultMaxIdemponentCallAttempts is used if not set. + MaxIdemponentCallAttempts int + // Per-connection buffer size for responses' reading. // This also limits the maximum header size. 
// @@ -209,21 +277,48 @@ // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool - mLock sync.Mutex - m map[string]*HostClient - ms map[string]*HostClient + // Path values are sent as-is without normalization + // + // Disabled path normalization may be useful for proxying incoming requests + // to servers that are expecting paths to be forwarded as-is. + // + // By default path values are normalized, i.e. + // extra slashes are removed, special characters are encoded. + DisablePathNormalizing bool + + // Maximum duration for waiting for a free connection. + // + // By default will not waiting, return ErrNoFreeConns immediately + MaxConnWaitTimeout time.Duration + + // RetryIf controls whether a retry should be attempted after an error. + // + // By default will use isIdempotent function + RetryIf RetryIfFunc + + mLock sync.Mutex + m map[string]*HostClient + ms map[string]*HostClient + readerPool sync.Pool + writerPool sync.Pool } -// Get appends url contents to dst and returns it as body. +// Get returns the status code and body of url. // -// New body buffer is allocated if dst is nil. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. +// +// The function follows redirects. Use Do* for manually handling redirects. func (c *Client) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } -// GetTimeout appends url contents to dst and returns it as body. +// GetTimeout returns the status code and body of url. +// +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. @@ -231,9 +326,12 @@ return clientGetURLTimeout(dst, url, timeout, c) } -// GetDeadline appends url contents to dst and returns it as body. +// GetDeadline returns the status code and body of url. // -// New body buffer is allocated if dst is nil. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. +// +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. @@ -243,9 +341,10 @@ // Post sends POST request to the given url with the given POST arguments. // -// Response body is appended to dst, which is returned as body. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func (c *Client) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -263,13 +362,23 @@ // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // +// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections +// to the requested host are busy. 
+// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. +// +// Warning: DoTimeout does not terminate the request itself. The request will +// continue in the background and the response will be discarded. +// If requests take too long and the connection pool gets filled up please +// try setting a ReadTimeout. func (c *Client) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return clientDoTimeout(req, resp, timeout, c) } @@ -285,29 +394,60 @@ // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until // the given deadline. // +// ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections +// to the requested host are busy. +// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *Client) DoDeadline(req *Request, resp *Response, deadline time.Time) error { return clientDoDeadline(req, resp, deadline, c) } -// Do performs the given http request and fills the given http response. +// DoRedirects performs the given http request and fills the given http response, +// following up to maxRedirectsCount redirects. When the redirect count exceeds +// maxRedirectsCount, ErrTooManyRedirects is returned. // // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// Client determines the server to be requested in the following order: +// +// - from RequestURI if it contains full url with scheme and host; +// - from Host header otherwise. +// // Response is ignored if resp is nil. // +// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections +// to the requested host are busy. +// +// It is recommended obtaining req and resp via AcquireRequest +// and AcquireResponse in performance-critical code. +func (c *Client) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c) + return err +} + +// Do performs the given http request and fills the given http response. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// // Client determines the server to be requested in the following order: // // - from RequestURI if it contains full url with scheme and host; // - from Host header otherwise. // +// Response is ignored if resp is nil. +// +// The function doesn't follow redirects. Use Get* for following redirects. +// // ErrNoFreeConns is returned if all Client.MaxConnsPerHost connections // to the requested host are busy. // @@ -315,6 +455,10 @@ // and AcquireResponse in performance-critical code. 
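Note on the DoTimeout warning above: it points callers at a client-level ReadTimeout so that timed-out requests cannot pile up on pooled connections. A hedged configuration sketch, assuming the time and fasthttp imports; every value is illustrative:

// Bound each request on the connection itself instead of relying solely
// on DoTimeout, which lets the request keep running in the background.
var apiClient = &fasthttp.Client{
	ReadTimeout:        5 * time.Second, // illustrative
	WriteTimeout:       5 * time.Second, // illustrative
	MaxConnsPerHost:    512,             // illustrative
	MaxConnWaitTimeout: time.Second,     // wait briefly instead of ErrNoFreeConns
}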
func (c *Client) Do(req *Request, resp *Response) error { uri := req.URI() + if uri == nil { + return ErrorInvalidURI + } + host := uri.Host() isTLS := false @@ -345,18 +489,26 @@ hc = &HostClient{ Addr: addMissingPort(string(host), isTLS), Name: c.Name, + NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, Dial: c.Dial, DialDualStack: c.DialDualStack, IsTLS: isTLS, TLSConfig: c.TLSConfig, MaxConns: c.MaxConnsPerHost, MaxIdleConnDuration: c.MaxIdleConnDuration, + MaxConnDuration: c.MaxConnDuration, + MaxIdemponentCallAttempts: c.MaxIdemponentCallAttempts, ReadBufferSize: c.ReadBufferSize, WriteBufferSize: c.WriteBufferSize, ReadTimeout: c.ReadTimeout, WriteTimeout: c.WriteTimeout, MaxResponseBodySize: c.MaxResponseBodySize, DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, + DisablePathNormalizing: c.DisablePathNormalizing, + MaxConnWaitTimeout: c.MaxConnWaitTimeout, + RetryIf: c.RetryIf, + clientReaderPool: &c.readerPool, + clientWriterPool: &c.writerPool, } m[string(host)] = hc if len(m) == 1 { @@ -372,13 +524,39 @@ return hc.Do(req, resp) } +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in a +// "keep-alive" state. It does not interrupt any connections currently +// in use. +func (c *Client) CloseIdleConnections() { + c.mLock.Lock() + for _, v := range c.m { + v.CloseIdleConnections() + } + for _, v := range c.ms { + v.CloseIdleConnections() + } + c.mLock.Unlock() +} + func (c *Client) mCleaner(m map[string]*HostClient) { mustStop := false + + sleep := c.MaxIdleConnDuration + if sleep < time.Second { + sleep = time.Second + } else if sleep > 10*time.Second { + sleep = 10 * time.Second + } + for { - t := time.Now() c.mLock.Lock() for k, v := range m { - if t.Sub(v.LastUseTime()) > time.Minute { + v.connsLock.Lock() + shouldRemove := v.connsCount == 0 + v.connsLock.Unlock() + + if shouldRemove { delete(m, k) } } @@ -390,7 +568,7 @@ if mustStop { break } - time.Sleep(10 * time.Second) + time.Sleep(sleep) } } @@ -403,6 +581,9 @@ // connection is closed. const DefaultMaxIdleConnDuration = 10 * time.Second +// DefaultMaxIdemponentCallAttempts is the default idempotent calls attempts count. +const DefaultMaxIdemponentCallAttempts = 5 + // DialFunc must establish connection to addr. // // There is no need in establishing TLS (SSL) connection for https. @@ -417,18 +598,29 @@ // - foobar.com:8080 type DialFunc func(addr string) (net.Conn, error) +// RetryIfFunc signature of retry if function +// +// Request argument passed to RetryIfFunc, if there are any request errors. +type RetryIfFunc func(request *Request) bool + +// TransportFunc wraps every request/response. +type TransportFunc func(*Request, *Response) error + // HostClient balances http requests among hosts listed in Addr. // // HostClient may be used for balancing load among multiple upstream hosts. +// While multiple addresses passed to HostClient.Addr may be used for balancing +// load among them, it would be better using LBClient instead, since HostClient +// may unevenly balance load among upstream hosts. // // It is forbidden copying HostClient instances. Create new instances instead. // // It is safe calling HostClient methods from concurrently running goroutines. type HostClient struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Comma-separated list of upstream HTTP server host addresses, - // which are passed to Dial in round-robin manner. + // which are passed to Dial in a round-robin manner. 
// // Each address may contain port if default dialer is used. // For example, @@ -441,6 +633,10 @@ // Client name. Used in User-Agent request header. Name string + // NoDefaultUserAgentHeader when set to true, causes the default + // User-Agent header to be excluded from the Request. + NoDefaultUserAgentHeader bool + // Callback for establishing new connection to the host. // // Default Dial is used if not set. @@ -465,6 +661,9 @@ // Maximum number of connections which may be established to all hosts // listed in Addr. // + // You can change this value while the HostClient is being used + // using HostClient.SetMaxConns(value) + // // DefaultMaxConnsPerHost is used if not set. MaxConns int @@ -479,6 +678,11 @@ // after DefaultMaxIdleConnDuration. MaxIdleConnDuration time.Duration + // Maximum number of attempts for idempotent calls + // + // DefaultMaxIdemponentCallAttempts is used if not set. + MaxIdemponentCallAttempts int + // Per-connection buffer size for responses' reading. // This also limits the maximum header size. // @@ -526,19 +730,60 @@ // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool + // Path values are sent as-is without normalization + // + // Disabled path normalization may be useful for proxying incoming requests + // to servers that are expecting paths to be forwarded as-is. + // + // By default path values are normalized, i.e. + // extra slashes are removed, special characters are encoded. + DisablePathNormalizing bool + + // Will not log potentially sensitive content in error logs + // + // This option is useful for servers that handle sensitive data + // in the request/response. + // + // Client logs full errors by default. + SecureErrorLogMessage bool + + // Maximum duration for waiting for a free connection. + // + // By default will not waiting, return ErrNoFreeConns immediately + MaxConnWaitTimeout time.Duration + + // RetryIf controls whether a retry should be attempted after an error. + // + // By default will use isIdempotent function + RetryIf RetryIfFunc + + // Transport defines a transport-like mechanism that wraps every request/response. + Transport TransportFunc + clientName atomic.Value lastUseTime uint32 connsLock sync.Mutex connsCount int conns []*clientConn + connsWait *wantConnQueue addrsLock sync.Mutex addrs []string addrIdx uint32 + tlsConfigMap map[string]*tls.Config + tlsConfigMapLock sync.Mutex + readerPool sync.Pool writerPool sync.Pool + + clientReaderPool *sync.Pool + clientWriterPool *sync.Pool + + pendingRequests int32 + + connsCleanerRun bool } type clientConn struct { @@ -546,9 +791,6 @@ createdTime time.Time lastUseTime time.Time - - lastReadDeadlineTime time.Time - lastWriteDeadlineTime time.Time } var startTimeUnix = time.Now().Unix() @@ -559,16 +801,22 @@ return time.Unix(startTimeUnix+int64(n), 0) } -// Get appends url contents to dst and returns it as body. +// Get returns the status code and body of url. // -// New body buffer is allocated if dst is nil. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. +// +// The function follows redirects. Use Do* for manually handling redirects. func (c *HostClient) Get(dst []byte, url string) (statusCode int, body []byte, err error) { return clientGetURL(dst, url, c) } -// GetTimeout appends url contents to dst and returns it as body. +// GetTimeout returns the status code and body of url. 
+// +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // during the given timeout. @@ -576,9 +824,12 @@ return clientGetURLTimeout(dst, url, timeout, c) } -// GetDeadline appends url contents to dst and returns it as body. +// GetDeadline returns the status code and body of url. // -// New body buffer is allocated if dst is nil. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. +// +// The function follows redirects. Use Do* for manually handling redirects. // // ErrTimeout error is returned if url contents couldn't be fetched // until the given deadline. @@ -588,9 +839,10 @@ // Post sends POST request to the given url with the given POST arguments. // -// Response body is appended to dst, which is returned as body. +// The contents of dst will be replaced by the body and returned, if the dst +// is too small a new slice will be allocated. // -// New body buffer is allocated if dst is nil. +// The function follows redirects. Use Do* for manually handling redirects. // // Empty POST body is sent if postArgs is nil. func (c *HostClient) Post(dst []byte, url string, postArgs *Args) (statusCode int, body []byte, err error) { @@ -604,7 +856,7 @@ func clientGetURL(dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { req := AcquireRequest() - statusCode, body, err = doRequestFollowRedirects(req, dst, url, c) + statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c) ReleaseRequest(req) return statusCode, body, err @@ -615,25 +867,13 @@ return clientGetURLDeadline(dst, url, deadline, c) } -func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { - var sleepTime time.Duration - for { - statusCode, body, err = clientGetURLDeadlineFreeConn(dst, url, deadline, c) - if err != ErrNoFreeConns { - return statusCode, body, err - } - sleepTime = updateSleepTime(sleepTime, deadline) - time.Sleep(sleepTime) - } -} - type clientURLResponse struct { statusCode int body []byte err error } -func clientGetURLDeadlineFreeConn(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { +func clientGetURLDeadline(dst []byte, url string, deadline time.Time, c clientDoer) (statusCode int, body []byte, err error) { timeout := -time.Since(deadline) if timeout <= 0 { return 0, dst, ErrTimeout @@ -646,8 +886,6 @@ } ch = chv.(chan clientURLResponse) - req := AcquireRequest() - // Note that the request continues execution on ErrTimeout until // client-specific ReadTimeout exceeds. This helps limiting load // on slow hosts by MaxConns* concurrent requests. @@ -655,28 +893,55 @@ // Without this 'hack' the load on slow host could exceed MaxConns* // concurrent requests, since timed out requests on client side // usually continue execution on the host. 
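Note on the reworded Get/GetTimeout/GetDeadline docs in this hunk: they all describe the same dst contract, where the returned body replaces dst and a new slice is allocated only when dst is too small. A small sketch of reusing the buffer across calls (not safe for concurrent use; the helper name is an assumption):

// Carry the returned slice forward so its capacity is reused on the next call.
var bodyBuf []byte

func poll(url string) (int, error) {
	statusCode, body, err := fasthttp.Get(bodyBuf, url)
	bodyBuf = body // body replaces dst's contents, so keep it as the next dst
	return statusCode, err
}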
+ + var mu sync.Mutex + var timedout, responded bool + go func() { - statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirects(req, dst, url, c) - ch <- clientURLResponse{ - statusCode: statusCodeCopy, - body: bodyCopy, - err: errCopy, + req := AcquireRequest() + + statusCodeCopy, bodyCopy, errCopy := doRequestFollowRedirectsBuffer(req, dst, url, c) + mu.Lock() + { + if !timedout { + ch <- clientURLResponse{ + statusCode: statusCodeCopy, + body: bodyCopy, + err: errCopy, + } + responded = true + } } + mu.Unlock() + + ReleaseRequest(req) }() - tc := acquireTimer(timeout) + tc := AcquireTimer(timeout) select { case resp := <-ch: - ReleaseRequest(req) - clientURLResponseChPool.Put(chv) statusCode = resp.statusCode body = resp.body err = resp.err case <-tc.C: - body = dst - err = ErrTimeout + mu.Lock() + { + if responded { + resp := <-ch + statusCode = resp.statusCode + body = resp.body + err = resp.err + } else { + timedout = true + err = ErrTimeout + body = dst + } + } + mu.Unlock() } - releaseTimer(tc) + ReleaseTimer(tc) + + clientURLResponseChPool.Put(chv) return statusCode, body, err } @@ -685,64 +950,81 @@ func clientPostURL(dst []byte, url string, postArgs *Args, c clientDoer) (statusCode int, body []byte, err error) { req := AcquireRequest() - req.Header.SetMethodBytes(strPost) + req.Header.SetMethod(MethodPost) req.Header.SetContentTypeBytes(strPostArgsContentType) if postArgs != nil { - postArgs.WriteTo(req.BodyWriter()) + if _, err := postArgs.WriteTo(req.BodyWriter()); err != nil { + return 0, nil, err + } } - statusCode, body, err = doRequestFollowRedirects(req, dst, url, c) + statusCode, body, err = doRequestFollowRedirectsBuffer(req, dst, url, c) ReleaseRequest(req) return statusCode, body, err } var ( - errMissingLocation = errors.New("missing Location header for http redirect") - errTooManyRedirects = errors.New("too many redirects detected when doing the request") + // ErrMissingLocation is returned by clients when the Location header is missing on + // an HTTP response with a redirect status code. + ErrMissingLocation = errors.New("missing Location header for http redirect") + // ErrTooManyRedirects is returned by clients when the number of redirects followed + // exceed the max count. + ErrTooManyRedirects = errors.New("too many redirects detected when doing the request") + + // HostClients are only able to follow redirects to the same protocol. 
+ ErrHostClientRedirectToDifferentScheme = errors.New("HostClient can't follow redirects to a different protocol, please use Client instead") ) -const maxRedirectsCount = 16 +const defaultMaxRedirectsCount = 16 -func doRequestFollowRedirects(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { +func doRequestFollowRedirectsBuffer(req *Request, dst []byte, url string, c clientDoer) (statusCode int, body []byte, err error) { resp := AcquireResponse() bodyBuf := resp.bodyBuffer() resp.keepBodyBuffer = true oldBody := bodyBuf.B bodyBuf.B = dst + statusCode, _, err = doRequestFollowRedirects(req, resp, url, defaultMaxRedirectsCount, c) + + body = bodyBuf.B + bodyBuf.B = oldBody + resp.keepBodyBuffer = false + ReleaseResponse(resp) + + return statusCode, body, err +} + +func doRequestFollowRedirects(req *Request, resp *Response, url string, maxRedirectsCount int, c clientDoer) (statusCode int, body []byte, err error) { redirectsCount := 0 + for { - req.parsedURI = false - req.Header.host = req.Header.host[:0] req.SetRequestURI(url) + if err := req.parseURI(); err != nil { + return 0, nil, err + } if err = c.Do(req, resp); err != nil { break } statusCode = resp.Header.StatusCode() - if statusCode != StatusMovedPermanently && statusCode != StatusFound && statusCode != StatusSeeOther { + if !StatusCodeIsRedirect(statusCode) { break } redirectsCount++ if redirectsCount > maxRedirectsCount { - err = errTooManyRedirects + err = ErrTooManyRedirects break } location := resp.Header.peek(strLocation) if len(location) == 0 { - err = errMissingLocation + err = ErrMissingLocation break } url = getRedirectURL(url, location) } - body = bodyBuf.B - bodyBuf.B = oldBody - resp.keepBodyBuffer = false - ReleaseResponse(resp) - return statusCode, body, err } @@ -755,6 +1037,15 @@ return redirectURL } +// StatusCodeIsRedirect returns true if the status code indicates a redirect. +func StatusCodeIsRedirect(statusCode int) bool { + return statusCode == StatusMovedPermanently || + statusCode == StatusFound || + statusCode == StatusSeeOther || + statusCode == StatusTemporaryRedirect || + statusCode == StatusPermanentRedirect +} + var ( requestPool sync.Pool responsePool sync.Pool @@ -810,13 +1101,23 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during // the given timeout. // +// ErrNoFreeConns is returned if all HostClient.MaxConns connections +// to the host are busy. +// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. +// +// Warning: DoTimeout does not terminate the request itself. The request will +// continue in the background and the response will be discarded. +// If requests take too long and the connection pool gets filled up please +// try setting a ReadTimeout. func (c *HostClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return clientDoTimeout(req, resp, timeout, c) } @@ -827,57 +1128,52 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. 
// // ErrTimeout is returned if the response wasn't returned until // the given deadline. // +// ErrNoFreeConns is returned if all HostClient.MaxConns connections +// to the host are busy. +// // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { return clientDoDeadline(req, resp, deadline, c) } +// DoRedirects performs the given http request and fills the given http response, +// following up to maxRedirectsCount redirects. When the redirect count exceeds +// maxRedirectsCount, ErrTooManyRedirects is returned. +// +// Request must contain at least non-zero RequestURI with full url (including +// scheme and host) or non-zero Host header + RequestURI. +// +// Client determines the server to be requested in the following order: +// +// - from RequestURI if it contains full url with scheme and host; +// - from Host header otherwise. +// +// Response is ignored if resp is nil. +// +// ErrNoFreeConns is returned if all DefaultMaxConnsPerHost connections +// to the requested host are busy. +// +// It is recommended obtaining req and resp via AcquireRequest +// and AcquireResponse in performance-critical code. +func (c *HostClient) DoRedirects(req *Request, resp *Response, maxRedirectsCount int) error { + _, _, err := doRequestFollowRedirects(req, resp, req.URI().String(), maxRedirectsCount, c) + return err +} + func clientDoTimeout(req *Request, resp *Response, timeout time.Duration, c clientDoer) error { deadline := time.Now().Add(timeout) return clientDoDeadline(req, resp, deadline, c) } func clientDoDeadline(req *Request, resp *Response, deadline time.Time, c clientDoer) error { - var sleepTime time.Duration - for { - err := clientDoDeadlineFreeConn(req, resp, deadline, c) - if err != ErrNoFreeConns { - return err - } - sleepTime = updateSleepTime(sleepTime, deadline) - time.Sleep(sleepTime) - } -} - -var sleepJitter uint64 - -func updateSleepTime(prevTime time.Duration, deadline time.Time) time.Duration { - sleepTime := prevTime * 2 - if sleepTime == 0 { - jitter := atomic.AddUint64(&sleepJitter, 1) % 40 - sleepTime = (10 + time.Duration(jitter)) * time.Millisecond - } - - remainingTime := deadline.Sub(time.Now()) - if sleepTime >= remainingTime { - // Just sleep for the remaining time and then time out. - // This should save CPU time for real work by other goroutines. - sleepTime = remainingTime + 10*time.Millisecond - if sleepTime < 0 { - sleepTime = 10 * time.Millisecond - } - } - - return sleepTime -} - -func clientDoDeadlineFreeConn(req *Request, resp *Response, deadline time.Time, c clientDoer) error { timeout := -time.Since(deadline) if timeout <= 0 { return ErrTimeout @@ -896,6 +1192,11 @@ req.copyToSkipBody(reqCopy) swapRequestBody(req, reqCopy) respCopy := AcquireResponse() + if resp != nil { + // Not calling resp.copyToSkipBody(respCopy) here to avoid + // unexpected messing with headers + respCopy.SkipBody = resp.SkipBody + } // Note that the request continues execution on ErrTimeout until // client-specific ReadTimeout exceeds. This helps limiting load @@ -904,25 +1205,50 @@ // Without this 'hack' the load on slow host could exceed MaxConns* // concurrent requests, since timed out requests on client side // usually continue execution on the host. 
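Note on the SkipBody handling in this hunk: the internal response copy now inherits resp.SkipBody, so a caller-set flag survives DoTimeout/DoDeadline. A brief sketch, assuming the time and fasthttp imports; the endpoint, timeout, and helper name are illustrative:

// Only the status code matters here, so skip reading the body.
func healthy(c *fasthttp.Client) bool {
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.com/health") // placeholder URL
	resp.SkipBody = true                           // preserved by the copy above
	if err := c.DoTimeout(req, resp, time.Second); err != nil {
		return false
	}
	return resp.Header.StatusCode() == fasthttp.StatusOK
}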
+ + var mu sync.Mutex + var timedout, responded bool + go func() { - ch <- c.Do(reqCopy, respCopy) + reqCopy.timeout = timeout + errDo := c.Do(reqCopy, respCopy) + mu.Lock() + { + if !timedout { + if resp != nil { + respCopy.copyToSkipBody(resp) + swapResponseBody(resp, respCopy) + } + swapRequestBody(reqCopy, req) + ch <- errDo + responded = true + } + } + mu.Unlock() + + ReleaseResponse(respCopy) + ReleaseRequest(reqCopy) }() - tc := acquireTimer(timeout) + tc := AcquireTimer(timeout) var err error select { case err = <-ch: - if resp != nil { - respCopy.copyToSkipBody(resp) - swapResponseBody(resp, respCopy) - } - ReleaseResponse(respCopy) - ReleaseRequest(reqCopy) - errorChPool.Put(chv) case <-tc.C: - err = ErrTimeout + mu.Lock() + { + if responded { + err = <-ch + } else { + timedout = true + err = ErrTimeout + } + } + mu.Unlock() } - releaseTimer(tc) + ReleaseTimer(tc) + + errorChPool.Put(chv) return err } @@ -934,6 +1260,8 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// The function doesn't follow redirects. Use Get* for following redirects. +// // Response is ignored if resp is nil. // // ErrNoFreeConns is returned if all HostClient.MaxConns connections @@ -942,16 +1270,63 @@ // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *HostClient) Do(req *Request, resp *Response) error { - retry, err := c.do(req, resp) - if err != nil && retry && isIdempotent(req) { - _, err = c.do(req, resp) + var err error + var retry bool + maxAttempts := c.MaxIdemponentCallAttempts + if maxAttempts <= 0 { + maxAttempts = DefaultMaxIdemponentCallAttempts + } + isRequestRetryable := isIdempotent + if c.RetryIf != nil { + isRequestRetryable = c.RetryIf + } + attempts := 0 + hasBodyStream := req.IsBodyStream() + + atomic.AddInt32(&c.pendingRequests, 1) + for { + retry, err = c.do(req, resp) + if err == nil || !retry { + break + } + + if hasBodyStream { + break + } + if !isRequestRetryable(req) { + // Retry non-idempotent requests if the server closes + // the connection before sending the response. + // + // This case is possible if the server closes the idle + // keep-alive connection on timeout. + // + // Apache and nginx usually do this. + if err != io.EOF { + break + } + } + attempts++ + if attempts >= maxAttempts { + break + } } + atomic.AddInt32(&c.pendingRequests, -1) + if err == io.EOF { err = ErrConnectionClosed } return err } +// PendingRequests returns the current number of requests the client +// is executing. +// +// This function may be used for balancing load among multiple HostClient +// instances. 
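Note on PendingRequests, declared immediately below: it is the hook the comment refers to for balancing load across several HostClient instances. A hedged sketch of picking the least-loaded upstream (assumes a non-empty hosts slice; the helper name is an assumption):

// Choose the upstream with the fewest in-flight requests.
func leastLoaded(hosts []*fasthttp.HostClient) *fasthttp.HostClient {
	best := hosts[0]
	for _, hc := range hosts[1:] {
		if hc.PendingRequests() < best.PendingRequests() {
			best = hc
		}
	}
	return best
}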
+func (c *HostClient) PendingRequests() int { + return int(atomic.LoadInt32(&c.pendingRequests)) +} + func isIdempotent(req *Request) bool { return req.Header.IsGet() || req.Header.IsHead() || req.Header.IsPut() } @@ -980,29 +1355,53 @@ panic("BUG: resp cannot be nil") } + // Secure header error logs configuration + resp.secureErrorLogMessage = c.SecureErrorLogMessage + resp.Header.secureErrorLogMessage = c.SecureErrorLogMessage + req.secureErrorLogMessage = c.SecureErrorLogMessage + req.Header.secureErrorLogMessage = c.SecureErrorLogMessage + + if c.IsTLS != bytes.Equal(req.uri.Scheme(), strHTTPS) { + return false, ErrHostClientRedirectToDifferentScheme + } + atomic.StoreUint32(&c.lastUseTime, uint32(time.Now().Unix()-startTimeUnix)) // Free up resources occupied by response before sending the request, // so the GC may reclaim these resources (e.g. response body). + + // backing up SkipBody in case it was set explicitly + customSkipBody := resp.SkipBody resp.Reset() + resp.SkipBody = customSkipBody - cc, err := c.acquireConn() - if err != nil { - return false, err - } - conn := cc.c + req.URI().DisablePathNormalizing = c.DisablePathNormalizing + + userAgentOld := req.Header.UserAgent() + if len(userAgentOld) == 0 { + req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + } + + if c.Transport != nil { + err := c.Transport(req, resp) + return err == nil, err + } + + cc, err := c.acquireConn(req.timeout, req.ConnectionClose()) + if err != nil { + return false, err + } + conn := cc.c + + resp.parseNetConn(conn) if c.WriteTimeout > 0 { - // Optimization: update write deadline only if more than 25% - // of the last write deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. + // Set Deadline every time, since golang has fixed the performance issue + // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details currentTime := time.Now() - if currentTime.Sub(cc.lastWriteDeadlineTime) > (c.WriteTimeout >> 2) { - if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil { - c.closeConn(cc) - return true, err - } - cc.lastWriteDeadlineTime = currentTime + if err = conn.SetWriteDeadline(currentTime.Add(c.WriteTimeout)); err != nil { + c.closeConn(cc) + return true, err } } @@ -1012,15 +1411,8 @@ resetConnection = true } - userAgentOld := req.Header.UserAgent() - if len(userAgentOld) == 0 { - req.Header.userAgent = c.getClientName() - } bw := c.acquireWriter(conn) err = req.Write(bw) - if len(userAgentOld) == 0 { - req.Header.userAgent = userAgentOld - } if resetConnection { req.Header.ResetConnectionClose() @@ -1037,20 +1429,16 @@ c.releaseWriter(bw) if c.ReadTimeout > 0 { - // Optimization: update read deadline only if more than 25% - // of the last read deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. 
+ // Set Deadline every time, since golang has fixed the performance issue + // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details currentTime := time.Now() - if currentTime.Sub(cc.lastReadDeadlineTime) > (c.ReadTimeout >> 2) { - if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil { - c.closeConn(cc) - return true, err - } - cc.lastReadDeadlineTime = currentTime + if err = conn.SetReadDeadline(currentTime.Add(c.ReadTimeout)); err != nil { + c.closeConn(cc) + return true, err } } - if !req.Header.IsGet() && req.Header.IsHead() { + if customSkipBody || req.Header.IsHead() { resp.SkipBody = true } if c.DisableHeaderNamesNormalizing { @@ -1061,10 +1449,9 @@ if err = resp.ReadLimitBody(br, c.MaxResponseBodySize); err != nil { c.releaseReader(br) c.closeConn(cc) - if err == io.EOF { - return true, err - } - return false, err + // Don't retry in case of ErrBodyTooLarge since we will just get the same again. + retry := err != ErrBodyTooLarge + return retry, err } c.releaseReader(br) @@ -1080,11 +1467,11 @@ var ( // ErrNoFreeConns is returned when no free connections available // to the given host. + // + // Increase the allowed number of connections per host if you + // see this error. ErrNoFreeConns = errors.New("no free connections available to host") - // ErrTimeout is returned from timed out calls. - ErrTimeout = errors.New("timeout") - // ErrConnectionClosed may be returned from client methods if the server // closes connection before returning the first response byte. // @@ -1096,8 +1483,31 @@ "Make sure the server returns 'Connection: close' response header before closing the connection") ) -func (c *HostClient) acquireConn() (*clientConn, error) { - var cc *clientConn +type timeoutError struct{} + +func (e *timeoutError) Error() string { + return "timeout" +} + +// Only implement the Timeout() function of the net.Error interface. +// This allows for checks like: +// +// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() { +func (e *timeoutError) Timeout() bool { + return true +} + +// ErrTimeout is returned from timed out calls. +var ErrTimeout = &timeoutError{} + +// SetMaxConns sets up the maximum number of connections which may be established to all hosts listed in Addr. 
+func (c *HostClient) SetMaxConns(newMaxConns int) { + c.connsLock.Lock() + c.MaxConns = newMaxConns + c.connsLock.Unlock() +} + +func (c *HostClient) acquireConn(reqTimeout time.Duration, connectionClose bool) (cc *clientConn, err error) { createConn := false startCleaner := false @@ -1112,13 +1522,15 @@ if c.connsCount < maxConns { c.connsCount++ createConn = true - } - if createConn && c.connsCount == 1 { - startCleaner = true + if !c.connsCleanerRun && !connectionClose { + startCleaner = true + c.connsCleanerRun = true + } } } else { n-- cc = c.conns[n] + c.conns[n] = nil c.conns = c.conns[:n] } c.connsLock.Unlock() @@ -1127,7 +1539,51 @@ return cc, nil } if !createConn { - return nil, ErrNoFreeConns + if c.MaxConnWaitTimeout <= 0 { + return nil, ErrNoFreeConns + } + + // reqTimeout c.MaxConnWaitTimeout wait duration + // d1 d2 min(d1, d2) + // 0(not set) d2 d2 + // d1 0(don't wait) 0(don't wait) + // 0(not set) d2 d2 + timeout := c.MaxConnWaitTimeout + timeoutOverridden := false + // reqTimeout == 0 means not set + if reqTimeout > 0 && reqTimeout < timeout { + timeout = reqTimeout + timeoutOverridden = true + } + + // wait for a free connection + tc := AcquireTimer(timeout) + defer ReleaseTimer(tc) + + w := &wantConn{ + ready: make(chan struct{}, 1), + } + defer func() { + if err != nil { + w.cancel(c, err) + } + }() + + c.queueForIdle(w) + + select { + case <-w.ready: + return w.conn, w.err + case <-tc.C: + if timeoutOverridden { + return nil, ErrTimeout + } + return nil, ErrNoFreeConns + } + } + + if startCleaner { + go c.connsCleaner() } conn, err := c.dialHostHard() @@ -1137,16 +1593,56 @@ } cc = acquireClientConn(conn) - if startCleaner { - go c.connsCleaner() - } return cc, nil } +func (c *HostClient) queueForIdle(w *wantConn) { + c.connsLock.Lock() + defer c.connsLock.Unlock() + if c.connsWait == nil { + c.connsWait = &wantConnQueue{} + } + c.connsWait.clearFront() + c.connsWait.pushBack(w) +} + +func (c *HostClient) dialConnFor(w *wantConn) { + conn, err := c.dialHostHard() + if err != nil { + w.tryDeliver(nil, err) + c.decConnsCount() + return + } + + cc := acquireClientConn(conn) + delivered := w.tryDeliver(cc, nil) + if !delivered { + // not delivered, return idle connection + c.releaseConn(cc) + } +} + +// CloseIdleConnections closes any connections which were previously +// connected from previous requests but are now sitting idle in a +// "keep-alive" state. It does not interrupt any connections currently +// in use. +func (c *HostClient) CloseIdleConnections() { + c.connsLock.Lock() + scratch := append([]*clientConn{}, c.conns...) + for i := range c.conns { + c.conns[i] = nil + } + c.conns = c.conns[:0] + c.connsLock.Unlock() + + for _, cc := range scratch { + c.closeConn(cc) + } +} + func (c *HostClient) connsCleaner() { var ( scratch []*clientConn - mustStop bool maxIdleConnDuration = c.MaxIdleConnDuration ) if maxIdleConnDuration <= 0 { @@ -1155,6 +1651,7 @@ for { currentTime := time.Now() + // Determine idle connections to be closed. c.connsLock.Lock() conns := c.conns n := len(conns) @@ -1162,7 +1659,12 @@ for i < n && currentTime.Sub(conns[i].lastUseTime) > maxIdleConnDuration { i++ } - mustStop = (c.connsCount == i) + sleepFor := maxIdleConnDuration + if i < n { + // + 1 so we actually sleep past the expiration time and not up to it. + // Otherwise the > check above would still fail. + sleepFor = maxIdleConnDuration - currentTime.Sub(conns[i].lastUseTime) + 1 + } scratch = append(scratch[:0], conns[:i]...) 
if i > 0 { m := copy(conns, conns[i:]) @@ -1173,14 +1675,24 @@ } c.connsLock.Unlock() + // Close idle connections. for i, cc := range scratch { c.closeConn(cc) scratch[i] = nil } + + // Determine whether to stop the connsCleaner. + c.connsLock.Lock() + mustStop := c.connsCount == 0 + if mustStop { + c.connsCleanerRun = false + } + c.connsLock.Unlock() if mustStop { break } - time.Sleep(maxIdleConnDuration) + + time.Sleep(sleepFor) } } @@ -1191,9 +1703,37 @@ } func (c *HostClient) decConnsCount() { + if c.MaxConnWaitTimeout <= 0 { + c.connsLock.Lock() + c.connsCount-- + c.connsLock.Unlock() + return + } + c.connsLock.Lock() - c.connsCount-- - c.connsLock.Unlock() + defer c.connsLock.Unlock() + dialed := false + if q := c.connsWait; q != nil && q.len() > 0 { + for q.len() > 0 { + w := q.popFront() + if w.waiting() { + go c.dialConnFor(w) + dialed = true + break + } + } + } + if !dialed { + c.connsCount-- + } +} + +// ConnsCount returns connection count of HostClient +func (c *HostClient) ConnsCount() int { + c.connsLock.Lock() + defer c.connsLock.Unlock() + + return c.connsCount } func acquireClientConn(conn net.Conn) *clientConn { @@ -1208,7 +1748,8 @@ } func releaseClientConn(cc *clientConn) { - cc.c = nil + // Reset all fields. + *cc = clientConn{} clientConnPool.Put(cc) } @@ -1216,52 +1757,132 @@ func (c *HostClient) releaseConn(cc *clientConn) { cc.lastUseTime = time.Now() + if c.MaxConnWaitTimeout <= 0 { + c.connsLock.Lock() + c.conns = append(c.conns, cc) + c.connsLock.Unlock() + return + } + + // try to deliver an idle connection to a *wantConn c.connsLock.Lock() - c.conns = append(c.conns, cc) - c.connsLock.Unlock() + defer c.connsLock.Unlock() + delivered := false + if q := c.connsWait; q != nil && q.len() > 0 { + for q.len() > 0 { + w := q.popFront() + if w.waiting() { + delivered = w.tryDeliver(cc, nil) + break + } + } + } + if !delivered { + c.conns = append(c.conns, cc) + } } func (c *HostClient) acquireWriter(conn net.Conn) *bufio.Writer { - v := c.writerPool.Get() - if v == nil { - n := c.WriteBufferSize - if n <= 0 { - n = defaultWriteBufferSize + var v interface{} + if c.clientWriterPool != nil { + v = c.clientWriterPool.Get() + if v == nil { + n := c.WriteBufferSize + if n <= 0 { + n = defaultWriteBufferSize + } + return bufio.NewWriterSize(conn, n) + } + } else { + v = c.writerPool.Get() + if v == nil { + n := c.WriteBufferSize + if n <= 0 { + n = defaultWriteBufferSize + } + return bufio.NewWriterSize(conn, n) } - return bufio.NewWriterSize(conn, n) } + bw := v.(*bufio.Writer) bw.Reset(conn) return bw } func (c *HostClient) releaseWriter(bw *bufio.Writer) { - c.writerPool.Put(bw) + if c.clientWriterPool != nil { + c.clientWriterPool.Put(bw) + } else { + c.writerPool.Put(bw) + } } func (c *HostClient) acquireReader(conn net.Conn) *bufio.Reader { - v := c.readerPool.Get() - if v == nil { - n := c.ReadBufferSize - if n <= 0 { - n = defaultReadBufferSize + var v interface{} + if c.clientReaderPool != nil { + v = c.clientReaderPool.Get() + if v == nil { + n := c.ReadBufferSize + if n <= 0 { + n = defaultReadBufferSize + } + return bufio.NewReaderSize(conn, n) + } + } else { + v = c.readerPool.Get() + if v == nil { + n := c.ReadBufferSize + if n <= 0 { + n = defaultReadBufferSize + } + return bufio.NewReaderSize(conn, n) } - return bufio.NewReaderSize(conn, n) } + br := v.(*bufio.Reader) br.Reset(conn) return br } func (c *HostClient) releaseReader(br *bufio.Reader) { - c.readerPool.Put(br) + if c.clientReaderPool != nil { + c.clientReaderPool.Put(br) + } else { + 
c.readerPool.Put(br) + } +} + +func newClientTLSConfig(c *tls.Config, addr string) *tls.Config { + if c == nil { + c = &tls.Config{} + } else { + c = c.Clone() + } + + if c.ClientSessionCache == nil { + c.ClientSessionCache = tls.NewLRUClientSessionCache(0) + } + + if len(c.ServerName) == 0 { + serverName := tlsServerName(addr) + if serverName == "*" { + c.InsecureSkipVerify = true + } else { + c.ServerName = serverName + } + } + return c } -func newDefaultTLSConfig() *tls.Config { - return &tls.Config{ - InsecureSkipVerify: true, - ClientSessionCache: tls.NewLRUClientSessionCache(0), +func tlsServerName(addr string) string { + if !strings.Contains(addr, ":") { + return addr + } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return "*" } + return host } func (c *HostClient) nextAddr() string { @@ -1297,7 +1918,8 @@ deadline := time.Now().Add(timeout) for n > 0 { addr := c.nextAddr() - conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, c.TLSConfig) + tlsConfig := c.cachedTLSConfig(addr) + conn, err = dialAddr(addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) if err == nil { return conn, nil } @@ -1309,7 +1931,63 @@ return nil, err } -func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config) (net.Conn, error) { +func (c *HostClient) cachedTLSConfig(addr string) *tls.Config { + if !c.IsTLS { + return nil + } + + c.tlsConfigMapLock.Lock() + if c.tlsConfigMap == nil { + c.tlsConfigMap = make(map[string]*tls.Config) + } + cfg := c.tlsConfigMap[addr] + if cfg == nil { + cfg = newClientTLSConfig(c.TLSConfig, addr) + c.tlsConfigMap[addr] = cfg + } + c.tlsConfigMapLock.Unlock() + + return cfg +} + +// ErrTLSHandshakeTimeout indicates there is a timeout from tls handshake. +var ErrTLSHandshakeTimeout = errors.New("tls handshake timed out") + +var timeoutErrorChPool sync.Pool + +func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { + tc := AcquireTimer(timeout) + defer ReleaseTimer(tc) + + var ch chan error + chv := timeoutErrorChPool.Get() + if chv == nil { + chv = make(chan error) + } + ch = chv.(chan error) + defer timeoutErrorChPool.Put(chv) + + conn := tls.Client(rawConn, tlsConfig) + + go func() { + ch <- conn.Handshake() + }() + + select { + case <-tc.C: + rawConn.Close() + <-ch + return nil, ErrTLSHandshakeTimeout + case err := <-ch: + if err != nil { + rawConn.Close() + return nil, err + } + return conn, nil + } +} + +func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { if dial == nil { if dialDualStack { dial = DialDualStack @@ -1325,11 +2003,12 @@ if conn == nil { panic("BUG: DialFunc returned (nil, nil)") } - if isTLS { - if tlsConfig == nil { - tlsConfig = newDefaultTLSConfig() + _, isTLSAlready := conn.(*tls.Conn) + if isTLS && !isTLSAlready { + if timeout == 0 { + return tls.Client(conn, tlsConfig), nil } - conn = tls.Client(conn, tlsConfig) + return tlsClientHandshake(conn, tlsConfig, timeout) } return conn, nil } @@ -1339,7 +2018,7 @@ var clientName []byte if v == nil { clientName = []byte(c.Name) - if len(clientName) == 0 { + if len(clientName) == 0 && !c.NoDefaultUserAgentHeader { clientName = defaultUserAgent } c.clientName.Store(clientName) @@ -1358,10 +2037,140 @@ if isTLS { port = 443 } - return fmt.Sprintf("%s:%d", addr, port) + return net.JoinHostPort(addr, strconv.Itoa(port)) } -// PipelineClient pipelines requests over a single connection to the given Addr. 
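Note on the wantConn and wantConnQueue types added below: they implement the MaxConnWaitTimeout option introduced earlier in this file, queueing callers for a free connection instead of failing immediately. Opting in is a single field; a sketch with illustrative values, assuming the time and fasthttp imports:

// Wait up to one second for a pooled connection before giving up,
// rather than returning ErrNoFreeConns right away.
var upstream = &fasthttp.HostClient{
	Addr:               "example.com:443", // placeholder upstream
	IsTLS:              true,
	MaxConns:           128,         // illustrative limit
	MaxConnWaitTimeout: time.Second, // callers queue for this long at most
}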
+// A wantConn records state about a wanted connection +// (that is, an active call to getConn). +// The conn may be gotten by dialing or by finding an idle connection, +// or a cancellation may make the conn no longer wanted. +// These three options are racing against each other and use +// wantConn to coordinate and agree about the winning outcome. +// +// inspired by net/http/transport.go +type wantConn struct { + ready chan struct{} + mu sync.Mutex // protects conn, err, close(ready) + conn *clientConn + err error +} + +// waiting reports whether w is still waiting for an answer (connection or error). +func (w *wantConn) waiting() bool { + select { + case <-w.ready: + return false + default: + return true + } +} + +// tryDeliver attempts to deliver conn, err to w and reports whether it succeeded. +func (w *wantConn) tryDeliver(conn *clientConn, err error) bool { + w.mu.Lock() + defer w.mu.Unlock() + + if w.conn != nil || w.err != nil { + return false + } + w.conn = conn + w.err = err + if w.conn == nil && w.err == nil { + panic("fasthttp: internal error: misuse of tryDeliver") + } + close(w.ready) + return true +} + +// cancel marks w as no longer wanting a result (for example, due to cancellation). +// If a connection has been delivered already, cancel returns it with c.releaseConn. +func (w *wantConn) cancel(c *HostClient, err error) { + w.mu.Lock() + if w.conn == nil && w.err == nil { + close(w.ready) // catch misbehavior in future delivery + } + + conn := w.conn + w.conn = nil + w.err = err + w.mu.Unlock() + + if conn != nil { + c.releaseConn(conn) + } +} + +// A wantConnQueue is a queue of wantConns. +// +// inspired by net/http/transport.go +type wantConnQueue struct { + // This is a queue, not a deque. + // It is split into two stages - head[headPos:] and tail. + // popFront is trivial (headPos++) on the first stage, and + // pushBack is trivial (append) on the second stage. + // If the first stage is empty, popFront can swap the + // first and second stages to remedy the situation. + // + // This two-stage split is analogous to the use of two lists + // in Okasaki's purely functional queue but without the + // overhead of reversing the list when swapping stages. + head []*wantConn + headPos int + tail []*wantConn +} + +// len returns the number of items in the queue. +func (q *wantConnQueue) len() int { + return len(q.head) - q.headPos + len(q.tail) +} + +// pushBack adds w to the back of the queue. +func (q *wantConnQueue) pushBack(w *wantConn) { + q.tail = append(q.tail, w) +} + +// popFront removes and returns the wantConn at the front of the queue. +func (q *wantConnQueue) popFront() *wantConn { + if q.headPos >= len(q.head) { + if len(q.tail) == 0 { + return nil + } + // Pick up tail as new head, clear tail. + q.head, q.headPos, q.tail = q.tail, 0, q.head[:0] + } + + w := q.head[q.headPos] + q.head[q.headPos] = nil + q.headPos++ + return w +} + +// peekFront returns the wantConn at the front of the queue without removing it. +func (q *wantConnQueue) peekFront() *wantConn { + if q.headPos < len(q.head) { + return q.head[q.headPos] + } + if len(q.tail) > 0 { + return q.tail[0] + } + return nil +} + +// cleanFront pops any wantConns that are no longer waiting from the head of the +// queue, reporting whether any were popped. 
+func (q *wantConnQueue) clearFront() (cleaned bool) { + for { + w := q.peekFront() + if w == nil || w.waiting() { + return cleaned + } + q.popFront() + cleaned = true + } +} + +// PipelineClient pipelines requests over a limited set of concurrent +// connections to the given Addr. // // This client may be used in highly loaded HTTP-based RPC systems for reducing // context switches and network level overhead. @@ -1373,12 +2182,25 @@ // It is safe calling PipelineClient methods from concurrently running // goroutines. type PipelineClient struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Address of the host to connect to. Addr string - // The maximum number of pending pipelined requests to the server. + // PipelineClient name. Used in User-Agent request header. + Name string + + // NoDefaultUserAgentHeader when set to true, causes the default + // User-Agent header to be excluded from the Request. + NoDefaultUserAgentHeader bool + + // The maximum number of concurrent connections to the Addr. + // + // A single connection is used by default. + MaxConns int + + // The maximum number of pending pipelined requests over + // a single connection to Addr. // // DefaultMaxPendingRequests is used by default. MaxPendingRequests int @@ -1404,6 +2226,33 @@ // since unfortunately ipv6 remains broken in many networks worldwide :) DialDualStack bool + // Response header names are passed as-is without normalization + // if this option is set. + // + // Disabled header names' normalization may be useful only for proxying + // responses to other clients expecting case-sensitive + // header names. See https://github.com/valyala/fasthttp/issues/57 + // for details. + // + // By default request and response header names are normalized, i.e. + // The first letter and the first letters following dashes + // are uppercased, while all the other letters are lowercased. + // Examples: + // + // * HOST -> Host + // * content-type -> Content-Type + // * cONTENT-lenGTH -> Content-Length + DisableHeaderNamesNormalizing bool + + // Path values are sent as-is without normalization + // + // Disabled path normalization may be useful for proxying incoming requests + // to servers that are expecting paths to be forwarded as-is. + // + // By default path values are normalized, i.e. + // extra slashes are removed, special characters are encoded. + DisablePathNormalizing bool + // Whether to use TLS (aka SSL or HTTPS) for host connections. IsTLS bool @@ -1442,11 +2291,40 @@ // By default standard logger from log package is used. Logger Logger + connClients []*pipelineConnClient + connClientsLock sync.Mutex +} + +type pipelineConnClient struct { + noCopy noCopy //nolint:unused,structcheck + + Addr string + Name string + NoDefaultUserAgentHeader bool + MaxPendingRequests int + MaxBatchDelay time.Duration + Dial DialFunc + DialDualStack bool + DisableHeaderNamesNormalizing bool + DisablePathNormalizing bool + IsTLS bool + TLSConfig *tls.Config + MaxIdleConnDuration time.Duration + ReadBufferSize int + WriteBufferSize int + ReadTimeout time.Duration + WriteTimeout time.Duration + Logger Logger + workPool sync.Pool chLock sync.Mutex chW chan *pipelineWork chR chan *pipelineWork + + tlsConfigLock sync.Mutex + tlsConfig *tls.Config + clientName atomic.Value } type pipelineWork struct { @@ -1466,6 +2344,8 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// The function doesn't follow redirects. 
+// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned during @@ -1473,6 +2353,11 @@ // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. +// +// Warning: DoTimeout does not terminate the request itself. The request will +// continue in the background and the response will be discarded. +// If requests take too long and the connection pool gets filled up please +// try setting a ReadTimeout. func (c *PipelineClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { return c.DoDeadline(req, resp, time.Now().Add(timeout)) } @@ -1483,6 +2368,8 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // +// The function doesn't follow redirects. +// // Response is ignored if resp is nil. // // ErrTimeout is returned if the response wasn't returned until @@ -1491,6 +2378,10 @@ // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *PipelineClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { + return c.getConnClient().DoDeadline(req, resp, deadline) +} + +func (c *pipelineConnClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { c.init() timeout := -time.Since(deadline) @@ -1498,7 +2389,17 @@ return ErrTimeout } + if c.DisablePathNormalizing { + req.URI().DisablePathNormalizing = true + } + + userAgentOld := req.Header.UserAgent() + if len(userAgentOld) == 0 { + req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + } + w := acquirePipelineWork(&c.workPool, timeout) + w.respCopy.Header.disableNormalizing = c.DisableHeaderNamesNormalizing w.req = &w.reqCopy w.resp = &w.respCopy @@ -1542,19 +2443,32 @@ // Request must contain at least non-zero RequestURI with full url (including // scheme and host) or non-zero Host header + RequestURI. // -// Response is ignored if resp is nil. +// The function doesn't follow redirects. Use Get* for following redirects. // -// ErrNoFreeConns is returned if all HostClient.MaxConns connections -// to the host are busy. +// Response is ignored if resp is nil. // // It is recommended obtaining req and resp via AcquireRequest // and AcquireResponse in performance-critical code. func (c *PipelineClient) Do(req *Request, resp *Response) error { + return c.getConnClient().Do(req, resp) +} + +func (c *pipelineConnClient) Do(req *Request, resp *Response) error { c.init() + if c.DisablePathNormalizing { + req.URI().DisablePathNormalizing = true + } + + userAgentOld := req.Header.UserAgent() + if len(userAgentOld) == 0 { + req.Header.userAgent = append(req.Header.userAgent[:0], c.getClientName()...) + } + w := acquirePipelineWork(&c.workPool, 0) w.req = req if resp != nil { + resp.Header.disableNormalizing = c.DisableHeaderNamesNormalizing w.resp = resp } else { w.resp = &w.respCopy @@ -1588,15 +2502,79 @@ return err } -// ErrPipelineOverflow may be returned from PipelineClient.Do +func (c *PipelineClient) getConnClient() *pipelineConnClient { + c.connClientsLock.Lock() + cc := c.getConnClientUnlocked() + c.connClientsLock.Unlock() + return cc +} + +func (c *PipelineClient) getConnClientUnlocked() *pipelineConnClient { + if len(c.connClients) == 0 { + return c.newConnClient() + } + + // Return the client with the minimum number of pending requests. 
+ minCC := c.connClients[0] + minReqs := minCC.PendingRequests() + if minReqs == 0 { + return minCC + } + for i := 1; i < len(c.connClients); i++ { + cc := c.connClients[i] + reqs := cc.PendingRequests() + if reqs == 0 { + return cc + } + if reqs < minReqs { + minCC = cc + minReqs = reqs + } + } + + maxConns := c.MaxConns + if maxConns <= 0 { + maxConns = 1 + } + if len(c.connClients) < maxConns { + return c.newConnClient() + } + return minCC +} + +func (c *PipelineClient) newConnClient() *pipelineConnClient { + cc := &pipelineConnClient{ + Addr: c.Addr, + Name: c.Name, + NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, + MaxPendingRequests: c.MaxPendingRequests, + MaxBatchDelay: c.MaxBatchDelay, + Dial: c.Dial, + DialDualStack: c.DialDualStack, + DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, + DisablePathNormalizing: c.DisablePathNormalizing, + IsTLS: c.IsTLS, + TLSConfig: c.TLSConfig, + MaxIdleConnDuration: c.MaxIdleConnDuration, + ReadBufferSize: c.ReadBufferSize, + WriteBufferSize: c.WriteBufferSize, + ReadTimeout: c.ReadTimeout, + WriteTimeout: c.WriteTimeout, + Logger: c.Logger, + } + c.connClients = append(c.connClients, cc) + return cc +} + +// ErrPipelineOverflow may be returned from PipelineClient.Do* // if the requests' queue is overflown. -var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflown. Increase MaxPendingRequests") +var ErrPipelineOverflow = errors.New("pipelined requests' queue has been overflown. Increase MaxConns and/or MaxPendingRequests") // DefaultMaxPendingRequests is the default value // for PipelineClient.MaxPendingRequests. const DefaultMaxPendingRequests = 1024 -func (c *PipelineClient) init() { +func (c *pipelineConnClient) init() { c.chLock.Lock() if c.chR == nil { maxPendingRequests := c.MaxPendingRequests @@ -1608,27 +2586,36 @@ c.chW = make(chan *pipelineWork, maxPendingRequests) } go func() { - if err := c.worker(); err != nil { - c.logger().Printf("error in PipelineClient(%q): %s", c.Addr, err) - if netErr, ok := err.(net.Error); ok && netErr.Temporary() { - // Throttle client reconnections on temporary errors - time.Sleep(time.Second) + // Keep restarting the worker if it fails (connection errors for example). + for { + if err := c.worker(); err != nil { + c.logger().Printf("error in PipelineClient(%q): %s", c.Addr, err) + if netErr, ok := err.(net.Error); ok && netErr.Temporary() { + // Throttle client reconnections on temporary errors + time.Sleep(time.Second) + } + } else { + c.chLock.Lock() + stop := len(c.chR) == 0 && len(c.chW) == 0 + if !stop { + c.chR = nil + c.chW = nil + } + c.chLock.Unlock() + + if stop { + break + } } } - - c.chLock.Lock() - // Do not reset c.chW to nil, since it may contain - // pending requests, which could be served on the next - // connection to the host. 
- c.chR = nil - c.chLock.Unlock() }() } c.chLock.Unlock() } -func (c *PipelineClient) worker() error { - conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, c.TLSConfig) +func (c *pipelineConnClient) worker() error { + tlsConfig := c.cachedTLSConfig() + conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) if err != nil { return err } @@ -1660,14 +2647,30 @@ // Notify pending readers for len(c.chR) > 0 { w := <-c.chR - w.err = errPipelineClientStopped + w.err = errPipelineConnStopped w.done <- struct{}{} } return err } -func (c *PipelineClient) writer(conn net.Conn, stopCh <-chan struct{}) error { +func (c *pipelineConnClient) cachedTLSConfig() *tls.Config { + if !c.IsTLS { + return nil + } + + c.tlsConfigLock.Lock() + cfg := c.tlsConfig + if cfg == nil { + cfg = newClientTLSConfig(c.TLSConfig, c.Addr) + c.tlsConfig = cfg + } + c.tlsConfigLock.Unlock() + + return cfg +} + +func (c *pipelineConnClient) writer(conn net.Conn, stopCh <-chan struct{}) error { writeBufferSize := c.WriteBufferSize if writeBufferSize <= 0 { writeBufferSize = defaultWriteBufferSize @@ -1692,8 +2695,6 @@ w *pipelineWork err error - - lastWriteDeadlineTime time.Time ) close(instantTimerCh) for { @@ -1725,18 +2726,16 @@ continue } + w.resp.parseNetConn(conn) + if writeTimeout > 0 { - // Optimization: update write deadline only if more than 25% - // of the last write deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. + // Set Deadline every time, since golang has fixed the performance issue + // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details currentTime := time.Now() - if currentTime.Sub(lastWriteDeadlineTime) > (writeTimeout >> 2) { - if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { - w.err = err - w.done <- struct{}{} - return err - } - lastWriteDeadlineTime = currentTime + if err = conn.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { + w.err = err + w.done <- struct{}{} + return err } } if err = w.req.Write(bw); err != nil { @@ -1762,7 +2761,7 @@ select { case chR <- w: case <-stopCh: - w.err = errPipelineClientStopped + w.err = errPipelineConnStopped w.done <- struct{}{} return nil case <-flushTimerCh: @@ -1778,7 +2777,7 @@ } } -func (c *PipelineClient) reader(conn net.Conn, stopCh <-chan struct{}) error { +func (c *pipelineConnClient) reader(conn net.Conn, stopCh <-chan struct{}) error { readBufferSize := c.ReadBufferSize if readBufferSize <= 0 { readBufferSize = defaultReadBufferSize @@ -1790,8 +2789,6 @@ var ( w *pipelineWork err error - - lastReadDeadlineTime time.Time ) for { select { @@ -1807,17 +2804,13 @@ } if readTimeout > 0 { - // Optimization: update read deadline only if more than 25% - // of the last read deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. 
+ // Set Deadline every time, since golang has fixed the performance issue + // See https://github.com/golang/go/issues/15133#issuecomment-271571395 for details currentTime := time.Now() - if currentTime.Sub(lastReadDeadlineTime) > (readTimeout >> 2) { - if err = conn.SetReadDeadline(currentTime.Add(readTimeout)); err != nil { - w.err = err - w.done <- struct{}{} - return err - } - lastReadDeadlineTime = currentTime + if err = conn.SetReadDeadline(currentTime.Add(readTimeout)); err != nil { + w.err = err + w.done <- struct{}{} + return err } } if err = w.resp.Read(br); err != nil { @@ -1830,7 +2823,7 @@ } } -func (c *PipelineClient) logger() Logger { +func (c *pipelineConnClient) logger() Logger { if c.Logger != nil { return c.Logger } @@ -1840,10 +2833,23 @@ // PendingRequests returns the current number of pending requests pipelined // to the server. // -// This number may exceed MaxPendingRequests by up to two times, since -// the client may keep up to MaxPendingRequests requests in the queue before -// sending them to the server. +// This number may exceed MaxPendingRequests*MaxConns by up to two times, since +// each connection to the server may keep up to MaxPendingRequests requests +// in the queue before sending them to the server. +// +// This function may be used for balancing load among multiple PipelineClient +// instances. func (c *PipelineClient) PendingRequests() int { + c.connClientsLock.Lock() + n := 0 + for _, cc := range c.connClients { + n += cc.PendingRequests() + } + c.connClientsLock.Unlock() + return n +} + +func (c *pipelineConnClient) PendingRequests() int { c.init() c.chLock.Lock() @@ -1852,7 +2858,22 @@ return n } -var errPipelineClientStopped = errors.New("pipeline client has been stopped") +func (c *pipelineConnClient) getClientName() []byte { + v := c.clientName.Load() + var clientName []byte + if v == nil { + clientName = []byte(c.Name) + if len(clientName) == 0 && !c.NoDefaultUserAgentHeader { + clientName = defaultUserAgent + } + c.clientName.Store(clientName) + } else { + clientName = v.([]byte) + } + return clientName +} + +var errPipelineConnStopped = errors.New("pipeline connection has been stopped") func acquirePipelineWork(pool *sync.Pool, timeout time.Duration) *pipelineWork { v := pool.Get() diff -Nru golang-github-valyala-fasthttp-20160617/client_test.go golang-github-valyala-fasthttp-1.31.0/client_test.go --- golang-github-valyala-fasthttp-20160617/client_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/client_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,11 +1,15 @@ package fasthttp import ( + "bufio" + "bytes" "crypto/tls" "fmt" "io" "net" + "net/url" "os" + "regexp" "runtime" "strings" "sync" @@ -16,31 +20,910 @@ "github.com/valyala/fasthttp/fasthttputil" ) +func TestCloseIdleConnections(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go func() { + if err := s.Serve(ln); err != nil { + t.Error(err) + } + }() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + if _, _, err := c.Get(nil, "http://google.com"); err != nil { + t.Fatal(err) + } + + connsLen := func() int { + c.mLock.Lock() + defer c.mLock.Unlock() + + if _, ok := c.m["google.com"]; !ok { + return 0 + } + + c.m["google.com"].connsLock.Lock() + defer c.m["google.com"].connsLock.Unlock() + + return len(c.m["google.com"].conns) + } + + if conns := connsLen(); conns > 1 { + t.Errorf("expected 1 conns got %d", 
conns) + } + + c.CloseIdleConnections() + + if conns := connsLen(); conns > 0 { + t.Errorf("expected 0 conns got %d", conns) + } +} + +func TestPipelineClientSetUserAgent(t *testing.T) { + t.Parallel() + + testPipelineClientSetUserAgent(t, 0) +} + +func TestPipelineClientSetUserAgentTimeout(t *testing.T) { + t.Parallel() + + testPipelineClientSetUserAgent(t, time.Second) +} + +func testPipelineClientSetUserAgent(t *testing.T, timeout time.Duration) { + ln := fasthttputil.NewInmemoryListener() + + userAgentSeen := "" + s := &Server{ + Handler: func(ctx *RequestCtx) { + userAgentSeen = string(ctx.UserAgent()) + }, + } + go s.Serve(ln) //nolint:errcheck + + userAgent := "I'm not fasthttp" + c := &HostClient{ + Name: userAgent, + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + var err error + if timeout <= 0 { + err = c.Do(req, res) + } else { + err = c.DoTimeout(req, res, timeout) + } + + if err != nil { + t.Fatal(err) + } + if userAgentSeen != userAgent { + t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent) + } +} + +func TestPipelineClientIssue832(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + req := AcquireRequest() + // Don't defer ReleaseRequest as we use it in a goroutine that might not be done at the end. + + req.SetHost("example.com") + + res := AcquireResponse() + // Don't defer ReleaseResponse as we use it in a goroutine that might not be done at the end. + + client := PipelineClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + ReadTimeout: time.Millisecond * 10, + Logger: &testLogger{}, // Ignore log output. + } + + attempts := 10 + go func() { + for i := 0; i < attempts; i++ { + c, err := ln.Accept() + if err != nil { + t.Error(err) + } + if c != nil { + go func() { + time.Sleep(time.Millisecond * 50) + c.Close() + }() + } + } + }() + + done := make(chan int) + go func() { + defer close(done) + + for i := 0; i < attempts; i++ { + if err := client.Do(req, res); err == nil { + t.Error("error expected") + } + } + }() + + select { + case <-time.After(time.Second * 2): + t.Fatal("PipelineClient did not restart worker") + case <-done: + } +} + +func TestClientInvalidURI(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + requests := int64(0) + s := &Server{ + Handler: func(ctx *RequestCtx) { + atomic.AddInt64(&requests, 1) + }, + } + go s.Serve(ln) //nolint:errcheck + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com\r\n\r\nGET /\r\n\r\n") + err := c.Do(req, res) + if err == nil { + t.Fatal("expected error (missing required Host header in request)") + } + if n := atomic.LoadInt64(&requests); n != 0 { + t.Fatalf("0 requests expected, got %d", n) + } +} + +func TestClientGetWithBody(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + body := ctx.Request.Body() + ctx.Write(body) //nolint:errcheck + }, + } + go s.Serve(ln) //nolint:errcheck + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + 
req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com") + req.SetBodyString("test") + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + if len(res.Body()) == 0 { + t.Fatal("missing request body") + } +} + +func TestClientURLAuth(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "user:pass@": "Basic dXNlcjpwYXNz", + "foo:@": "Basic Zm9vOg==", + ":@": "", + "@": "", + "": "", + } + + ch := make(chan string, 1) + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + ch <- string(ctx.Request.Header.Peek(HeaderAuthorization)) + }, + } + go s.Serve(ln) //nolint:errcheck + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + for up, expected := range cases { + req := AcquireRequest() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://" + up + "example.com/foo/bar") + if err := c.Do(req, nil); err != nil { + t.Fatal(err) + } + + val := <-ch + + if val != expected { + t.Fatalf("wrong %s header: %s expected %s", HeaderAuthorization, val, expected) + } + } +} + +func TestClientNilResp(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go s.Serve(ln) //nolint:errcheck + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com") + if err := c.Do(req, nil); err != nil { + t.Fatal(err) + } + if err := c.DoTimeout(req, nil, time.Second); err != nil { + t.Fatal(err) + } + ln.Close() +} + +func TestPipelineClientNilResp(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go s.Serve(ln) //nolint:errcheck + c := &PipelineClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://example.com") + if err := c.Do(req, nil); err != nil { + t.Fatal(err) + } + if err := c.DoTimeout(req, nil, time.Second); err != nil { + t.Fatal(err) + } + if err := c.DoDeadline(req, nil, time.Now().Add(time.Second)); err != nil { + t.Fatal(err) + } +} + +func TestClientParseConn(t *testing.T) { + t.Parallel() + + network := "tcp" + ln, _ := net.Listen(network, "127.0.0.1:0") + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + go s.Serve(ln) //nolint:errcheck + host := ln.Addr().String() + c := &Client{} + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + req.SetRequestURI("http://" + host + "") + if err := c.Do(req, res); err != nil { + t.Fatal(err) + } + + if res.RemoteAddr().Network() != network { + t.Fatalf("req RemoteAddr parse network fail: %s, hope: %s", res.RemoteAddr().Network(), network) + } + if host != res.RemoteAddr().String() { + t.Fatalf("req RemoteAddr parse addr fail: %s, hope: %s", res.RemoteAddr().String(), host) + } + + if !regexp.MustCompile(`^127\.0\.0\.1:[0-9]{4,5}$`).MatchString(res.LocalAddr().String()) { + t.Fatalf("res LocalAddr addr match fail: %s, hope match: %s", res.LocalAddr().String(), "^127.0.0.1:[0-9]{4,5}$") + } +} + +func TestClientPostArgs(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + body := ctx.Request.Body() + if len(body) == 0 { + return + } + ctx.Write(body) //nolint:errcheck + }, + } + go 
s.Serve(ln) //nolint:errcheck + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req, res := AcquireRequest(), AcquireResponse() + defer func() { + ReleaseRequest(req) + ReleaseResponse(res) + }() + args := req.PostArgs() + args.Add("addhttp2", "support") + args.Add("fast", "http") + req.Header.SetMethod(MethodPost) + req.SetRequestURI("http://make.fasthttp.great?again") + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + if len(res.Body()) == 0 { + t.Fatal("cannot set args as body") + } +} + +func TestClientRedirectSameSchema(t *testing.T) { + t.Parallel() + + listenHTTPS1 := testClientRedirectListener(t, true) + defer listenHTTPS1.Close() + + listenHTTPS2 := testClientRedirectListener(t, true) + defer listenHTTPS2.Close() + + sHTTPS1 := testClientRedirectChangingSchemaServer(t, listenHTTPS1, listenHTTPS1, true) + defer sHTTPS1.Stop() + + sHTTPS2 := testClientRedirectChangingSchemaServer(t, listenHTTPS2, listenHTTPS2, false) + defer sHTTPS2.Stop() + + destURL := fmt.Sprintf("https://%s/baz", listenHTTPS1.Addr().String()) + + urlParsed, err := url.Parse(destURL) + if err != nil { + t.Fatal(err) + return + } + + reqClient := &HostClient{ + IsTLS: true, + Addr: urlParsed.Host, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond) + if err != nil { + t.Fatalf("HostClient error: %s", err) + return + } + + if statusCode != 200 { + t.Fatalf("HostClient error code response %d", statusCode) + return + } +} + +func TestClientRedirectClientChangingSchemaHttp2Https(t *testing.T) { + t.Parallel() + + listenHTTPS := testClientRedirectListener(t, true) + defer listenHTTPS.Close() + + listenHTTP := testClientRedirectListener(t, false) + defer listenHTTP.Close() + + sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true) + defer sHTTPS.Stop() + + sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false) + defer sHTTP.Stop() + + destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String()) + + reqClient := &Client{ + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + statusCode, _, err := reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond) + if err != nil { + t.Fatalf("HostClient error: %s", err) + return + } + + if statusCode != 200 { + t.Fatalf("HostClient error code response %d", statusCode) + return + } +} + +func TestClientRedirectHostClientChangingSchemaHttp2Https(t *testing.T) { + t.Parallel() + + listenHTTPS := testClientRedirectListener(t, true) + defer listenHTTPS.Close() + + listenHTTP := testClientRedirectListener(t, false) + defer listenHTTP.Close() + + sHTTPS := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, true) + defer sHTTPS.Stop() + + sHTTP := testClientRedirectChangingSchemaServer(t, listenHTTPS, listenHTTP, false) + defer sHTTP.Stop() + + destURL := fmt.Sprintf("http://%s/baz", listenHTTP.Addr().String()) + + urlParsed, err := url.Parse(destURL) + if err != nil { + t.Fatal(err) + return + } + + reqClient := &HostClient{ + Addr: urlParsed.Host, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + + _, _, err = reqClient.GetTimeout(nil, destURL, 4000*time.Millisecond) + if err != ErrHostClientRedirectToDifferentScheme { + t.Fatal("expected HostClient error") + } +} + +func testClientRedirectListener(t *testing.T, isTLS bool) net.Listener { + var ln net.Listener + var err error + var tlsConfig *tls.Config + + if isTLS { + certData, 
keyData, kerr := GenerateTestCertificate("localhost") + if kerr != nil { + t.Fatal(kerr) + } + + cert, kerr := tls.X509KeyPair(certData, keyData) + if kerr != nil { + t.Fatal(kerr) + } + + tlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + ln, err = tls.Listen("tcp", "localhost:0", tlsConfig) + } else { + ln, err = net.Listen("tcp", "localhost:0") + } + + if err != nil { + t.Fatalf("cannot listen isTLS %v: %s", isTLS, err) + } + + return ln +} + +func testClientRedirectChangingSchemaServer(t *testing.T, https, http net.Listener, isTLS bool) *testEchoServer { + s := &Server{ + Handler: func(ctx *RequestCtx) { + if ctx.IsTLS() { + ctx.SetStatusCode(200) + } else { + ctx.Redirect(fmt.Sprintf("https://%s/baz", https.Addr().String()), 301) + } + }, + } + + var ln net.Listener + if isTLS { + ln = https + } else { + ln = http + } + + ch := make(chan struct{}) + go func() { + err := s.Serve(ln) + if err != nil { + t.Errorf("unexpected error returned from Serve(): %s", err) + } + close(ch) + }() + return &testEchoServer{ + s: s, + ln: ln, + ch: ch, + t: t, + } +} + +func TestClientHeaderCase(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + c, err := ln.Accept() + if err != nil { + t.Error(err) + } + c.Write([]byte("HTTP/1.1 200 OK\r\n" + //nolint:errcheck + "content-type: text/plain\r\n" + + "transfer-encoding: chunked\r\n\r\n" + + "24\r\nThis is the data in the first chunk \r\n" + + "1B\r\nand this is the second one \r\n" + + "0\r\n\r\n", + )) + }() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + ReadTimeout: time.Millisecond * 10, + + // Even without name normalizing we should parse headers correctly. + DisableHeaderNamesNormalizing: true, + } + + code, body, err := c.Get(nil, "http://example.com") + if err != nil { + t.Error(err) + } else if code != 200 { + t.Errorf("expected status code 200 got %d", code) + } else if string(body) != "This is the data in the first chunk and this is the second one " { + t.Errorf("wrong body: %q", body) + } +} + +func TestClientReadTimeout(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + timeout := false + s := &Server{ + Handler: func(ctx *RequestCtx) { + if timeout { + time.Sleep(time.Second) + } else { + timeout = true + } + }, + Logger: &testLogger{}, // Don't print closed pipe errors. + } + go s.Serve(ln) //nolint:errcheck + + c := &HostClient{ + ReadTimeout: time.Millisecond * 400, + MaxIdemponentCallAttempts: 1, + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://localhost") + + // Setting Connection: Close will make the connection be + // returned to the pool. + req.SetConnectionClose() + + if err := c.Do(req, res); err != nil { + t.Fatal(err) + } + + ReleaseRequest(req) + ReleaseResponse(res) + + done := make(chan struct{}) + go func() { + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://localhost") + req.SetConnectionClose() + + if err := c.Do(req, res); err != ErrTimeout { + t.Errorf("expected ErrTimeout got %#v", err) + } + + ReleaseRequest(req) + ReleaseResponse(res) + close(done) + }() + + select { + case <-done: + // This shouldn't take longer than the timeout times the number of requests it is going to try to do. + // Give it an extra second just to be sure. 
+ case <-time.After(c.ReadTimeout*time.Duration(c.MaxIdemponentCallAttempts) + time.Second): + t.Fatal("Client.ReadTimeout didn't work") + } +} + +func TestClientDefaultUserAgent(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + userAgentSeen := "" + s := &Server{ + Handler: func(ctx *RequestCtx) { + userAgentSeen = string(ctx.UserAgent()) + }, + } + go s.Serve(ln) //nolint:errcheck + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + if userAgentSeen != string(defaultUserAgent) { + t.Fatalf("User-Agent defers %q != %q", userAgentSeen, defaultUserAgent) + } +} + +func TestClientSetUserAgent(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + userAgentSeen := "" + s := &Server{ + Handler: func(ctx *RequestCtx) { + userAgentSeen = string(ctx.UserAgent()) + }, + } + go s.Serve(ln) //nolint:errcheck + + userAgent := "I'm not fasthttp" + c := &Client{ + Name: userAgent, + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + if userAgentSeen != userAgent { + t.Fatalf("User-Agent defers %q != %q", userAgentSeen, userAgent) + } +} + +func TestClientNoUserAgent(t *testing.T) { + ln := fasthttputil.NewInmemoryListener() + + userAgentSeen := "" + s := &Server{ + Handler: func(ctx *RequestCtx) { + userAgentSeen = string(ctx.UserAgent()) + }, + } + go s.Serve(ln) //nolint:errcheck + + c := &Client{ + NoDefaultUserAgentHeader: true, + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + req := AcquireRequest() + res := AcquireResponse() + + req.SetRequestURI("http://example.com") + + err := c.Do(req, res) + if err != nil { + t.Fatal(err) + } + if userAgentSeen != "" { + t.Fatalf("User-Agent wrong %q != %q", userAgentSeen, "") + } +} + +func TestClientDoWithCustomHeaders(t *testing.T) { + t.Parallel() + + // make sure that the client sends all the request headers and body. + ln := fasthttputil.NewInmemoryListener() + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + uri := "/foo/bar/baz?a=b&cd=12" + headers := map[string]string{ + "Foo": "bar", + "Host": "xxx.com", + "Content-Type": "asdfsdf", + "a-b-c-d-f": "", + } + body := "request body" + + ch := make(chan error) + go func() { + conn, err := ln.Accept() + if err != nil { + ch <- fmt.Errorf("cannot accept client connection: %s", err) + return + } + br := bufio.NewReader(conn) + + var req Request + if err = req.Read(br); err != nil { + ch <- fmt.Errorf("cannot read client request: %s", err) + return + } + if string(req.Header.Method()) != MethodPost { + ch <- fmt.Errorf("unexpected request method: %q. Expecting %q", req.Header.Method(), MethodPost) + return + } + reqURI := req.RequestURI() + if string(reqURI) != uri { + ch <- fmt.Errorf("unexpected request uri: %q. Expecting %q", reqURI, uri) + return + } + for k, v := range headers { + hv := req.Header.Peek(k) + if string(hv) != v { + ch <- fmt.Errorf("unexpected value for header %q: %q. Expecting %q", k, hv, v) + return + } + } + cl := req.Header.ContentLength() + if cl != len(body) { + ch <- fmt.Errorf("unexpected content-length %d. 
Expecting %d", cl, len(body)) + return + } + reqBody := req.Body() + if string(reqBody) != body { + ch <- fmt.Errorf("unexpected request body: %q. Expecting %q", reqBody, body) + return + } + + var resp Response + bw := bufio.NewWriter(conn) + if err = resp.Write(bw); err != nil { + ch <- fmt.Errorf("cannot send response: %s", err) + return + } + if err = bw.Flush(); err != nil { + ch <- fmt.Errorf("cannot flush response: %s", err) + return + } + + ch <- nil + }() + + var req Request + req.Header.SetMethod(MethodPost) + req.SetRequestURI(uri) + for k, v := range headers { + req.Header.Set(k, v) + } + req.SetBodyString(body) + + var resp Response + + err := c.DoTimeout(&req, &resp, time.Second) + if err != nil { + t.Fatalf("error when doing request: %s", err) + } + + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("timeout") + } +} + func TestPipelineClientDoSerial(t *testing.T) { - testPipelineClientDoConcurrent(t, 1, 0) + t.Parallel() + + testPipelineClientDoConcurrent(t, 1, 0, 0) } func TestPipelineClientDoConcurrent(t *testing.T) { - testPipelineClientDoConcurrent(t, 10, 0) + t.Parallel() + + testPipelineClientDoConcurrent(t, 10, 0, 1) } func TestPipelineClientDoBatchDelayConcurrent(t *testing.T) { - testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond) + t.Parallel() + + testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 1) +} + +func TestPipelineClientDoBatchDelayConcurrentMultiConn(t *testing.T) { + t.Parallel() + + testPipelineClientDoConcurrent(t, 10, 5*time.Millisecond, 3) } -func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration) { +func testPipelineClientDoConcurrent(t *testing.T, concurrency int, maxBatchDelay time.Duration, maxConns int) { ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.WriteString("OK") + ctx.WriteString("OK") //nolint:errcheck }, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverStopCh) }() @@ -49,10 +932,10 @@ Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, - MaxIdleConnDuration: 23 * time.Millisecond, - MaxPendingRequests: 6, - MaxBatchDelay: maxBatchDelay, - Logger: &customLogger{}, + MaxConns: maxConns, + MaxPendingRequests: concurrency, + MaxBatchDelay: maxBatchDelay, + Logger: &testLogger{}, } clientStopCh := make(chan struct{}, concurrency) @@ -120,47 +1003,347 @@ ReleaseResponse(resp) } -func TestClientDoTimeoutDisableNormalizing(t *testing.T) { +func TestPipelineClientDoDisableHeaderNamesNormalizing(t *testing.T) { + t.Parallel() + + testPipelineClientDisableHeaderNamesNormalizing(t, 0) +} + +func TestPipelineClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) { + t.Parallel() + + testPipelineClientDisableHeaderNamesNormalizing(t, time.Second) +} + +func testPipelineClientDisableHeaderNamesNormalizing(t *testing.T, timeout time.Duration) { + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Response.Header.Set("foo-BAR", "baz") + }, + DisableHeaderNamesNormalizing: true, + } + + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &PipelineClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + DisableHeaderNamesNormalizing: true, + } + + var req Request + 
req.SetRequestURI("http://aaaai.com/bsdf?sddfsd") + var resp Response + for i := 0; i < 5; i++ { + if timeout > 0 { + if err := c.DoTimeout(&req, &resp, timeout); err != nil { + t.Fatalf("unexpected error: %s", err) + } + } else { + if err := c.Do(&req, &resp); err != nil { + t.Fatalf("unexpected error: %s", err) + } + } + hv := resp.Header.Peek("foo-BAR") + if string(hv) != "baz" { + t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz") + } + hv = resp.Header.Peek("Foo-Bar") + if len(hv) > 0 { + t.Fatalf("unexpected non-empty header value %q", hv) + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestClientDoTimeoutDisableHeaderNamesNormalizing(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Response.Header.Set("foo-BAR", "baz") + }, + DisableHeaderNamesNormalizing: true, + } + + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + DisableHeaderNamesNormalizing: true, + } + + var req Request + req.SetRequestURI("http://aaaai.com/bsdf?sddfsd") + var resp Response + for i := 0; i < 5; i++ { + if err := c.DoTimeout(&req, &resp, time.Second); err != nil { + t.Fatalf("unexpected error: %s", err) + } + hv := resp.Header.Peek("foo-BAR") + if string(hv) != "baz" { + t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz") + } + hv = resp.Header.Peek("Foo-Bar") + if len(hv) > 0 { + t.Fatalf("unexpected non-empty header value %q", hv) + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + uri := ctx.URI() + uri.DisablePathNormalizing = true + ctx.Response.Header.Set("received-uri", string(uri.FullURI())) + }, + } + + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + DisablePathNormalizing: true, + } + + urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff" + + var req Request + req.SetRequestURI(urlWithEncodedPath) + var resp Response + for i := 0; i < 5; i++ { + if err := c.DoTimeout(&req, &resp, time.Second); err != nil { + t.Fatalf("unexpected error: %s", err) + } + hv := resp.Header.Peek("received-uri") + if string(hv) != urlWithEncodedPath { + t.Fatalf("request uri was normalized: %q. 
Expecting %q", hv, urlWithEncodedPath) + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestHostClientPendingRequests(t *testing.T) { + t.Parallel() + + const concurrency = 10 + doneCh := make(chan struct{}) + readyCh := make(chan struct{}, concurrency) + s := &Server{ + Handler: func(ctx *RequestCtx) { + readyCh <- struct{}{} + <-doneCh + }, + } ln := fasthttputil.NewInmemoryListener() + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + + pendingRequests := c.PendingRequests() + if pendingRequests != 0 { + t.Fatalf("non-zero pendingRequests: %d", pendingRequests) + } + + resultCh := make(chan error, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + resp := AcquireResponse() + + if err := c.DoTimeout(req, resp, 10*time.Second); err != nil { + resultCh <- fmt.Errorf("unexpected error: %s", err) + return + } + + if resp.StatusCode() != StatusOK { + resultCh <- fmt.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + return + } + resultCh <- nil + }() + } + + // wait while all the requests reach server + for i := 0; i < concurrency; i++ { + select { + case <-readyCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } + + pendingRequests = c.PendingRequests() + if pendingRequests != concurrency { + t.Fatalf("unexpected pendingRequests: %d. Expecting %d", pendingRequests, concurrency) + } + + // unblock request handlers on the server and wait until all the requests are finished. 
+ close(doneCh) + for i := 0; i < concurrency; i++ { + select { + case err := <-resultCh: + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } + + pendingRequests = c.PendingRequests() + if pendingRequests != 0 { + t.Fatalf("non-zero pendingRequests: %d", pendingRequests) + } + + // stop the server + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +func TestHostClientMaxConnsWithDeadline(t *testing.T) { + t.Parallel() + + var ( + emptyBodyCount uint8 + ln = fasthttputil.NewInmemoryListener() + timeout = 200 * time.Millisecond + wg sync.WaitGroup + ) s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.Response.Header.Set("foo-BAR", "baz") + if len(ctx.PostBody()) == 0 { + emptyBodyCount++ + } + + ctx.WriteString("foo") //nolint:errcheck }, - DisableHeaderNamesNormalizing: true, } - serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverStopCh) }() - c := &Client{ + c := &HostClient{ + Addr: "foobar", Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, - DisableHeaderNamesNormalizing: true, + MaxConns: 1, } - var req Request - req.SetRequestURI("http://aaaai.com/bsdf?sddfsd") - var resp Response for i := 0; i < 5; i++ { - if err := c.DoTimeout(&req, &resp, time.Second); err != nil { - t.Fatalf("unexpected error: %s", err) - } - hv := resp.Header.Peek("foo-BAR") - if string(hv) != "baz" { - t.Fatalf("unexpected header value: %q. Expecting %q", hv, "baz") - } - hv = resp.Header.Peek("Foo-Bar") - if len(hv) > 0 { - t.Fatalf("unexpected non-empty header value %q", hv) - } + wg.Add(1) + go func() { + defer wg.Done() + + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.Header.SetMethod(MethodPost) + req.SetBodyString("bar") + resp := AcquireResponse() + + for { + if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil { + if err == ErrNoFreeConns { + time.Sleep(time.Millisecond) + continue + } + t.Errorf("unexpected error: %s", err) + } + break + } + + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + body := resp.Body() + if string(body) != "foo" { + t.Errorf("unexpected body %q. 
Expecting %q", body, "abcd") + } + }() } + wg.Wait() if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) @@ -170,15 +1353,21 @@ case <-time.After(time.Second): t.Fatalf("timeout") } + + if emptyBodyCount > 0 { + t.Fatalf("at least one request body was empty") + } } func TestHostClientMaxConnDuration(t *testing.T) { + t.Parallel() + ln := fasthttputil.NewInmemoryListener() connectionCloseCount := uint32(0) s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.WriteString("abcd") + ctx.WriteString("abcd") //nolint:errcheck if ctx.Request.ConnectionClose() { atomic.AddUint32(&connectionCloseCount, 1) } @@ -187,7 +1376,7 @@ serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverStopCh) }() @@ -229,18 +1418,20 @@ } func TestHostClientMultipleAddrs(t *testing.T) { + t.Parallel() + ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.Write(ctx.Host()) + ctx.Write(ctx.Host()) //nolint:errcheck ctx.SetConnectionClose() }, } serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverStopCh) }() @@ -287,7 +1478,8 @@ } func TestClientFollowRedirects(t *testing.T) { - addr := "127.0.0.1:55234" + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { switch string(ctx.Path()) { @@ -304,22 +1496,25 @@ } }, } - ln, err := net.Listen("tcp4", addr) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } + ln := fasthttputil.NewInmemoryListener() serverStopCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverStopCh) }() - uri := fmt.Sprintf("http://%s/foo", addr) + c := &HostClient{ + Addr: "xxx", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + } + for i := 0; i < 10; i++ { - statusCode, body, err := GetTimeout(nil, uri, time.Second) + statusCode, body, err := c.GetTimeout(nil, "http://xxx/foo", time.Second) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -331,9 +1526,8 @@ } } - uri = fmt.Sprintf("http://%s/aaab/sss", addr) for i := 0; i < 10; i++ { - statusCode, body, err := Get(nil, uri) + statusCode, body, err := c.Get(nil, "http://xxx/aaab/sss") if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -344,61 +1538,99 @@ t.Fatalf("unexpected response %q. Expecting %q", body, "/aaab/sss") } } + + for i := 0; i < 10; i++ { + req := AcquireRequest() + resp := AcquireResponse() + + req.SetRequestURI("http://xxx/foo") + + err := c.DoRedirects(req, resp, 16) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if statusCode := resp.StatusCode(); statusCode != StatusOK { + t.Fatalf("unexpected status code: %d", statusCode) + } + + if body := string(resp.Body()); body != "/bar" { + t.Fatalf("unexpected response %q. 
Expecting %q", body, "/bar") + } + + ReleaseRequest(req) + ReleaseResponse(resp) + } + + req := AcquireRequest() + resp := AcquireResponse() + + req.SetRequestURI("http://xxx/foo") + + err := c.DoRedirects(req, resp, 0) + if have, want := err, ErrTooManyRedirects; have != want { + t.Fatalf("want error: %v, have %v", want, have) + } + + ReleaseRequest(req) + ReleaseResponse(resp) } func TestClientGetTimeoutSuccess(t *testing.T) { - addr := "127.0.0.1:56889" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr - testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) + testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) } func TestClientGetTimeoutSuccessConcurrent(t *testing.T) { - addr := "127.0.0.1:56989" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() - testClientGetTimeoutSuccess(t, &defaultClient, addr, 100) + testClientGetTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) }() } wg.Wait() } func TestClientDoTimeoutSuccess(t *testing.T) { - addr := "127.0.0.1:63897" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr - testClientDoTimeoutSuccess(t, &defaultClient, addr, 100) + testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) } func TestClientDoTimeoutSuccessConcurrent(t *testing.T) { - addr := "127.0.0.1:63898" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) go func() { defer wg.Done() - testClientDoTimeoutSuccess(t, &defaultClient, addr, 100) + testClientDoTimeoutSuccess(t, &defaultClient, "http://"+s.Addr(), 100) }() } wg.Wait() } func TestClientGetTimeoutError(t *testing.T) { + t.Parallel() + c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil @@ -409,6 +1641,8 @@ } func TestClientGetTimeoutErrorConcurrent(t *testing.T) { + t.Parallel() + c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil @@ -428,6 +1662,8 @@ } func TestClientDoTimeoutError(t *testing.T) { + t.Parallel() + c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil @@ -438,6 +1674,8 @@ } func TestClientDoTimeoutErrorConcurrent(t *testing.T) { + t.Parallel() + c := &Client{ Dial: func(addr string) (net.Conn, error) { return &readTimeoutConn{t: time.Second}, nil @@ -508,16 +1746,118 @@ return nil } -func TestClientIdempotentRequest(t *testing.T) { +func (r *readTimeoutConn) LocalAddr() net.Addr { + return nil +} + +func (r *readTimeoutConn) RemoteAddr() net.Addr { + return nil +} + +func TestClientNonIdempotentRetry(t *testing.T) { + t.Parallel() + + dialsCount := 0 + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + dialsCount++ + switch dialsCount { + case 1, 2: + return &readErrorConn{}, nil + case 3: + return &singleReadConn{ + s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456", + }, nil + default: + t.Fatalf("unexpected number of dials: %d", dialsCount) + } + panic("unreachable") + }, + } + + // This POST must succeed, since the readErrorConn closes + // the connection 
before sending any response. + // So the client must retry non-idempotent request. + dialsCount = 0 + statusCode, body, err := c.Post(nil, "http://foobar/a/b", nil) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != 345 { + t.Fatalf("unexpected status code: %d. Expecting 345", statusCode) + } + if string(body) != "0123456" { + t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456") + } + + // Verify that idempotent GET succeeds. + dialsCount = 0 + statusCode, body, err = c.Get(nil, "http://foobar/a/b") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != 345 { + t.Fatalf("unexpected status code: %d. Expecting 345", statusCode) + } + if string(body) != "0123456" { + t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456") + } +} + +func TestClientNonIdempotentRetry_BodyStream(t *testing.T) { + t.Parallel() + dialsCount := 0 c := &Client{ Dial: func(addr string) (net.Conn, error) { + dialsCount++ switch dialsCount { - case 0: - dialsCount++ + case 1, 2: return &readErrorConn{}, nil + case 3: + return &singleEchoConn{ + b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"), + }, nil + default: + t.Fatalf("unexpected number of dials: %d", dialsCount) + } + panic("unreachable") + }, + } + + dialsCount = 0 + + req := Request{} + res := Response{} + + req.SetRequestURI("http://foobar/a/b") + req.Header.SetMethod("POST") + body := bytes.NewBufferString("test") + req.SetBodyStream(body, body.Len()) + + err := c.Do(&req, &res) + if err == nil { + t.Fatal("expected error from being unable to retry a bodyStream") + } +} + +func TestClientIdempotentRequest(t *testing.T) { + t.Parallel() + + dialsCount := 0 + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + dialsCount++ + switch dialsCount { case 1: - dialsCount++ + return &singleReadConn{ + s: "invalid response", + }, nil + case 2: + return &writeErrorConn{}, nil + case 3: + return &readErrorConn{}, nil + case 4: return &singleReadConn{ s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456", }, nil @@ -528,6 +1868,7 @@ }, } + // idempotent GET must succeed. statusCode, body, err := c.Get(nil, "http://foobar/a/b") if err != nil { t.Fatalf("unexpected error: %s", err) @@ -539,19 +1880,157 @@ t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456") } - var args Args + var args Args + + // non-idempotent POST must fail on incorrect singleReadConn + dialsCount = 0 + _, _, err = c.Post(nil, "http://foobar/a/b", &args) + if err == nil { + t.Fatalf("expecting error") + } + + // non-idempotent POST must fail on incorrect singleReadConn + dialsCount = 0 + _, _, err = c.Post(nil, "http://foobar/a/b", nil) + if err == nil { + t.Fatalf("expecting error") + } +} + +func TestClientRetryRequestWithCustomDecider(t *testing.T) { + t.Parallel() + + dialsCount := 0 + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + dialsCount++ + switch dialsCount { + case 1: + return &singleReadConn{ + s: "invalid response", + }, nil + case 2: + return &writeErrorConn{}, nil + case 3: + return &readErrorConn{}, nil + case 4: + return &singleReadConn{ + s: "HTTP/1.1 345 OK\r\nContent-Type: foobar\r\nContent-Length: 7\r\n\r\n0123456", + }, nil + default: + t.Fatalf("unexpected number of dials: %d", dialsCount) + } + panic("unreachable") + }, + RetryIf: func(req *Request) bool { + return req.URI().String() == "http://foobar/a/b" + }, + } + + var args Args + + // Post must succeed for http://foobar/a/b uri. 
+ statusCode, body, err := c.Post(nil, "http://foobar/a/b", &args) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != 345 { + t.Fatalf("unexpected status code: %d. Expecting 345", statusCode) + } + if string(body) != "0123456" { + t.Fatalf("unexpected body: %q. Expecting %q", body, "0123456") + } + + // POST must fail for http://foobar/a/b/c uri. + dialsCount = 0 + _, _, err = c.Post(nil, "http://foobar/a/b/c", &args) + if err == nil { + t.Fatalf("expecting error") + } +} + +func TestHostClientTransport(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("abcd") //nolint:errcheck + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Transport: func() TransportFunc { + c, _ := ln.Dial() + + br := bufio.NewReader(c) + bw := bufio.NewWriter(c) + + return func(req *Request, res *Response) error { + if err := req.Write(bw); err != nil { + return err + } + + if err := bw.Flush(); err != nil { + return err + } + + return res.Read(br) + } + }(), + } + + for i := 0; i < 5; i++ { + statusCode, body, err := c.Get(nil, "http://aaaa.com/bbb/cc") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if statusCode != StatusOK { + t.Fatalf("unexpected status code %d. Expecting %d", statusCode, StatusOK) + } + if string(body) != "abcd" { + t.Fatalf("unexpected body %q. Expecting %q", body, "abcd") + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } +} + +type writeErrorConn struct { + net.Conn +} + +func (w *writeErrorConn) Write(p []byte) (int, error) { + return 1, fmt.Errorf("error") +} - dialsCount = 0 - statusCode, body, err = c.Post(nil, "http://foobar/a/b", &args) - if err == nil { - t.Fatalf("expecting error") - } +func (w *writeErrorConn) Close() error { + return nil +} - dialsCount = 0 - statusCode, body, err = c.Post(nil, "http://foobar/a/b", nil) - if err == nil { - t.Fatalf("expecting error") - } +func (w *writeErrorConn) LocalAddr() net.Addr { + return nil +} + +func (w *writeErrorConn) RemoteAddr() net.Addr { + return nil } type readErrorConn struct { @@ -570,6 +2049,14 @@ return nil } +func (r *readErrorConn) LocalAddr() net.Addr { + return nil +} + +func (r *readErrorConn) RemoteAddr() net.Addr { + return nil +} + type singleReadConn struct { net.Conn s string @@ -593,38 +2080,134 @@ return nil } +func (r *singleReadConn) LocalAddr() net.Addr { + return nil +} + +func (r *singleReadConn) RemoteAddr() net.Addr { + return nil +} + +type singleEchoConn struct { + net.Conn + b []byte + n int +} + +func (r *singleEchoConn) Read(p []byte) (int, error) { + if len(r.b) == r.n { + return 0, io.EOF + } + n := copy(p, r.b[r.n:]) + r.n += n + return n, nil +} + +func (r *singleEchoConn) Write(p []byte) (int, error) { + r.b = append(r.b, p...) 
+ return len(p), nil +} + +func (r *singleEchoConn) Close() error { + return nil +} + +func (r *singleEchoConn) LocalAddr() net.Addr { + return nil +} + +func (r *singleEchoConn) RemoteAddr() net.Addr { + return nil +} + +func TestSingleEchoConn(t *testing.T) { + t.Parallel() + + c := &Client{ + Dial: func(addr string) (net.Conn, error) { + return &singleEchoConn{ + b: []byte("HTTP/1.1 345 OK\r\nContent-Type: foobar\r\n\r\n"), + }, nil + }, + } + + req := Request{} + res := Response{} + + req.SetRequestURI("http://foobar/a/b") + req.Header.SetMethod("POST") + req.Header.Set("Content-Type", "text/plain") + body := bytes.NewBufferString("test") + req.SetBodyStream(body, body.Len()) + + err := c.Do(&req, &res) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if res.StatusCode() != 345 { + t.Fatalf("unexpected status code: %d. Expecting 345", res.StatusCode()) + } + expected := "POST /a/b HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar\r\nContent-Type: text/plain\r\nContent-Length: 4\r\n\r\ntest" + if string(res.Body()) != expected { + t.Fatalf("unexpected body: %q. Expecting %q", res.Body(), expected) + } +} + +func TestClientHTTPSInvalidServerName(t *testing.T) { + t.Parallel() + + sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:") + defer sHTTPS.Stop() + + var c Client + + for i := 0; i < 10; i++ { + _, _, err := c.GetTimeout(nil, "https://"+sHTTPS.Addr(), time.Second) + if err == nil { + t.Fatalf("expecting TLS error") + } + } +} + func TestClientHTTPSConcurrent(t *testing.T) { - addrHTTP := "127.0.0.1:56793" - sHTTP := startEchoServer(t, "tcp", addrHTTP) + t.Parallel() + + sHTTP := startEchoServer(t, "tcp", "127.0.0.1:") defer sHTTP.Stop() - addrHTTPS := "127.0.0.1:56794" - sHTTPS := startEchoServerTLS(t, "tcp", addrHTTPS) + sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:") defer sHTTPS.Stop() + c := &Client{ + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + } + var wg sync.WaitGroup for i := 0; i < 4; i++ { wg.Add(1) - addr := "http://" + addrHTTP + addr := "http://" + sHTTP.Addr() if i&1 != 0 { - addr = "https://" + addrHTTPS + addr = "https://" + sHTTPS.Addr() } go func() { defer wg.Done() - testClientGet(t, &defaultClient, addr, 20) - testClientPost(t, &defaultClient, addr, 10) + testClientGet(t, c, addr, 20) + testClientPost(t, c, addr, 10) }() } wg.Wait() } func TestClientManyServers(t *testing.T) { + t.Parallel() + var addrs []string for i := 0; i < 10; i++ { - addr := fmt.Sprintf("127.0.0.1:%d", 56904+i) - s := startEchoServer(t, "tcp", addr) + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addrs = append(addrs, addr) + addrs = append(addrs, s.Addr()) } var wg sync.WaitGroup @@ -641,29 +2224,30 @@ } func TestClientGet(t *testing.T) { - addr := "127.0.0.1:56789" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr - testClientGet(t, &defaultClient, addr, 100) + testClientGet(t, &defaultClient, "http://"+s.Addr(), 100) } func TestClientPost(t *testing.T) { - addr := "127.0.0.1:56798" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr - testClientPost(t, &defaultClient, addr, 100) + testClientPost(t, &defaultClient, "http://"+s.Addr(), 100) } func TestClientConcurrent(t *testing.T) { - addr := "127.0.0.1:55780" - s := startEchoServer(t, "tcp", addr) + t.Parallel() + + s := startEchoServer(t, "tcp", "127.0.0.1:") defer s.Stop() - addr = "http://" + addr + 
addr := "http://" + s.Addr() var wg sync.WaitGroup for i := 0; i < 10; i++ { wg.Add(1) @@ -687,6 +2271,8 @@ } func TestHostClientGet(t *testing.T) { + t.Parallel() + skipIfNotUnix(t) addr := "TestHostClientGet.unix" s := startEchoServer(t, "unix", addr) @@ -697,6 +2283,8 @@ } func TestHostClientPost(t *testing.T) { + t.Parallel() + skipIfNotUnix(t) addr := "./TestHostClientPost.unix" s := startEchoServer(t, "unix", addr) @@ -707,6 +2295,8 @@ } func TestHostClientConcurrent(t *testing.T) { + t.Parallel() + skipIfNotUnix(t) addr := "./TestHostClientConcurrent.unix" s := startEchoServer(t, "unix", addr) @@ -738,9 +2328,6 @@ t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) } resultURI := string(body) - if strings.HasPrefix(uri, "https") { - resultURI = uri[:5] + resultURI[4:] - } if resultURI != uri { t.Fatalf("unexpected uri %q. Expecting %q", resultURI, uri) } @@ -856,6 +2443,10 @@ } } +func (s *testEchoServer) Addr() string { + return s.ln.Addr().String() +} + func startEchoServerTLS(t *testing.T, network, addr string) *testEchoServer { return startEchoServerExt(t, network, addr, true) } @@ -871,12 +2462,16 @@ var ln net.Listener var err error if isTLS { - certFile := "./ssl-cert-snakeoil.pem" - keyFile := "./ssl-cert-snakeoil.key" - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - t.Fatalf("Cannot load TLS certificate: %s", err) + certData, keyData, kerr := GenerateTestCertificate("localhost") + if kerr != nil { + t.Fatal(kerr) } + + cert, kerr := tls.X509KeyPair(certData, keyData) + if kerr != nil { + t.Fatal(kerr) + } + tlsConfig := &tls.Config{ Certificates: []tls.Certificate{cert}, } @@ -893,15 +2488,16 @@ if ctx.IsGet() { ctx.Success("text/plain", ctx.URI().FullURI()) } else if ctx.IsPost() { - ctx.PostArgs().WriteTo(ctx) + ctx.PostArgs().WriteTo(ctx) //nolint:errcheck } }, + Logger: &testLogger{}, // Ignore log output. 
} ch := make(chan struct{}) go func() { err := s.Serve(ln) if err != nil { - t.Fatalf("unexpected error returned from Serve(): %s", err) + t.Errorf("unexpected error returned from Serve(): %s", err) } close(ch) }() @@ -912,3 +2508,311 @@ t: t, } } + +func TestClientTLSHandshakeTimeout(t *testing.T) { + t.Parallel() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + addr := listener.Addr().String() + defer listener.Close() + + complete := make(chan bool) + defer close(complete) + + go func() { + conn, err := listener.Accept() + if err != nil { + t.Error(err) + return + } + <-complete + conn.Close() + }() + + client := Client{ + WriteTimeout: 100 * time.Millisecond, + ReadTimeout: 100 * time.Millisecond, + } + + _, _, err = client.Get(nil, "https://"+addr) + if err == nil { + t.Fatal("tlsClientHandshake completed successfully") + } + + if err != ErrTLSHandshakeTimeout { + t.Errorf("resulting error not a timeout: %v\nType %T: %#v", err, err, err) + } +} + +func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { + t.Parallel() + + var ( + emptyBodyCount uint8 + ln = fasthttputil.NewInmemoryListener() + wg sync.WaitGroup + ) + + s := &Server{ + Handler: func(ctx *RequestCtx) { + if len(ctx.PostBody()) == 0 { + emptyBodyCount++ + } + time.Sleep(5 * time.Millisecond) + ctx.WriteString("foo") //nolint:errcheck + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxConns: 1, + MaxConnWaitTimeout: time.Second * 2, + } + + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.Header.SetMethod(MethodPost) + req.SetBodyString("bar") + resp := AcquireResponse() + + if err := c.Do(req, resp); err != nil { + t.Errorf("unexpected error: %s", err) + } + + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + body := resp.Body() + if string(body) != "foo" { + t.Errorf("unexpected body %q. 
Expecting %q", body, "abcd") + } + }() + } + wg.Wait() + + if c.connsWait.len() > 0 { + t.Errorf("connsWait has %v items remaining", c.connsWait.len()) + } + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second * 5): + t.Fatalf("timeout") + } + + if emptyBodyCount > 0 { + t.Fatalf("at least one request body was empty") + } +} + +func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { + t.Parallel() + + var ( + emptyBodyCount uint8 + ln = fasthttputil.NewInmemoryListener() + wg sync.WaitGroup + ) + + s := &Server{ + Handler: func(ctx *RequestCtx) { + if len(ctx.PostBody()) == 0 { + emptyBodyCount++ + } + time.Sleep(5 * time.Millisecond) + ctx.WriteString("foo") //nolint:errcheck + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxConns: 1, + MaxConnWaitTimeout: 10 * time.Millisecond, + } + + var errNoFreeConnsCount uint32 + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.Header.SetMethod(MethodPost) + req.SetBodyString("bar") + resp := AcquireResponse() + + if err := c.Do(req, resp); err != nil { + if err != ErrNoFreeConns { + t.Errorf("unexpected error: %s. Expecting %s", err, ErrNoFreeConns) + } + atomic.AddUint32(&errNoFreeConnsCount, 1) + } else { + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + body := resp.Body() + if string(body) != "foo" { + t.Errorf("unexpected body %q. Expecting %q", body, "abcd") + } + } + }() + } + wg.Wait() + + // Prevent a race condition with the conns cleaner that might still be running. + c.connsLock.Lock() + defer c.connsLock.Unlock() + + if c.connsWait.len() > 0 { + t.Errorf("connsWait has %v items remaining", c.connsWait.len()) + } + if errNoFreeConnsCount == 0 { + t.Errorf("unexpected errorCount: %d. 
Expecting > 0", errNoFreeConnsCount) + } + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if emptyBodyCount > 0 { + t.Fatalf("at least one request body was empty") + } +} + +func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) { + t.Parallel() + + var ( + emptyBodyCount uint8 + ln = fasthttputil.NewInmemoryListener() + wg sync.WaitGroup + // make deadline reach earlier than conns wait timeout + sleep = 100 * time.Millisecond + timeout = 10 * time.Millisecond + maxConnWaitTimeout = 50 * time.Millisecond + ) + + s := &Server{ + Handler: func(ctx *RequestCtx) { + if len(ctx.PostBody()) == 0 { + emptyBodyCount++ + } + time.Sleep(sleep) + ctx.WriteString("foo") //nolint:errcheck + }, + } + serverStopCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverStopCh) + }() + + c := &HostClient{ + Addr: "foobar", + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + MaxConns: 1, + MaxConnWaitTimeout: maxConnWaitTimeout, + } + + var errTimeoutCount uint32 + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + req := AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.Header.SetMethod(MethodPost) + req.SetBodyString("bar") + resp := AcquireResponse() + + if err := c.DoDeadline(req, resp, time.Now().Add(timeout)); err != nil { + if err != ErrTimeout { + t.Errorf("unexpected error: %s. Expecting %s", err, ErrTimeout) + } + atomic.AddUint32(&errTimeoutCount, 1) + } else { + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + body := resp.Body() + if string(body) != "foo" { + t.Errorf("unexpected body %q. Expecting %q", body, "abcd") + } + } + }() + } + wg.Wait() + + c.connsLock.Lock() + for { + w := c.connsWait.popFront() + if w == nil { + break + } + w.mu.Lock() + if w.err != nil && w.err != ErrTimeout { + t.Errorf("unexpected error: %s. Expecting %s", w.err, ErrTimeout) + } + w.mu.Unlock() + } + c.connsLock.Unlock() + if errTimeoutCount == 0 { + t.Errorf("unexpected errTimeoutCount: %d. 
Expecting > 0", errTimeoutCount) + } + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { + case <-serverStopCh: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if emptyBodyCount > 0 { + t.Fatalf("at least one request body was empty") + } +} diff -Nru golang-github-valyala-fasthttp-20160617/client_timing_test.go golang-github-valyala-fasthttp-1.31.0/client_timing_test.go --- golang-github-valyala-fasthttp-20160617/client_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/client_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -51,6 +51,20 @@ return nil } +func (c *fakeClientConn) LocalAddr() net.Addr { + return &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + } +} + +func (c *fakeClientConn) RemoteAddr() net.Addr { + return &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + } +} + func releaseFakeServerConn(c *fakeClientConn) { c.n = 0 fakeClientConnPool.Put(c) @@ -143,7 +157,7 @@ nn := uint32(0) b.RunParallel(func(pb *testing.PB) { - req, err := http.NewRequest("GET", fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil) + req, err := http.NewRequest(MethodGet, fmt.Sprintf("http://foobar%d.com/aaa/bbb", atomic.AddUint32(&nn, 1)), nil) if err != nil { b.Fatalf("unexpected error: %s", err) } @@ -172,8 +186,8 @@ } func nethttpEchoHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.Write([]byte(r.RequestURI)) + w.Header().Set(HeaderContentType, "text/plain") + w.Write([]byte(r.RequestURI)) //nolint:errcheck } func BenchmarkClientGetEndToEnd1TCP(b *testing.B) { @@ -199,7 +213,7 @@ ch := make(chan struct{}) go func() { if err := Serve(ln, fasthttpEchoHandler); err != nil { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -260,7 +274,7 @@ go func() { if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -328,7 +342,7 @@ ch := make(chan struct{}) go func() { if err := Serve(ln, fasthttpEchoHandler); err != nil { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -389,7 +403,7 @@ go func() { if err := http.Serve(ln, http.HandlerFunc(nethttpEchoHandler)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -444,7 +458,7 @@ bigResponse := createFixedBody(1024 * 1024) h := func(ctx *RequestCtx) { ctx.SetContentType("text/plain") - ctx.Write(bigResponse) + ctx.Write(bigResponse) //nolint:errcheck } ln := fasthttputil.NewInmemoryListener() @@ -452,7 +466,7 @@ ch := make(chan struct{}) go func() { if err := Serve(ln, h); err != nil { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -502,8 +516,8 @@ func benchmarkNetHTTPClientEndToEndBigResponseInmemory(b *testing.B, parallelism int) { bigResponse := createFixedBody(1024 * 1024) h := func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - w.Write(bigResponse) + w.Header().Set(HeaderContentType, "text/plain") + w.Write(bigResponse) //nolint:errcheck } ln := 
fasthttputil.NewInmemoryListener() @@ -511,7 +525,7 @@ go func() { if err := http.Serve(ln, http.HandlerFunc(h)); err != nil && !strings.Contains( err.Error(), "use of closed network connection") { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() @@ -528,7 +542,7 @@ url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { - req, err := http.NewRequest("GET", url, nil) + req, err := http.NewRequest(MethodGet, url, nil) if err != nil { b.Fatalf("unexpected error: %s", err) } @@ -577,36 +591,31 @@ func benchmarkPipelineClient(b *testing.B, parallelism int) { h := func(ctx *RequestCtx) { - ctx.WriteString("foobar") + ctx.WriteString("foobar") //nolint:errcheck } ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { if err := Serve(ln, h); err != nil { - b.Fatalf("error when serving requests: %s", err) + b.Errorf("error when serving requests: %s", err) } close(ch) }() - var clients []*PipelineClient - for i := 0; i < runtime.GOMAXPROCS(-1); i++ { - c := &PipelineClient{ - Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, - ReadBufferSize: 1024 * 1024, - WriteBufferSize: 1024 * 1024, - MaxPendingRequests: parallelism, - } - clients = append(clients, c) + maxConns := runtime.GOMAXPROCS(-1) + c := &PipelineClient{ + Dial: func(addr string) (net.Conn, error) { return ln.Dial() }, + ReadBufferSize: 1024 * 1024, + WriteBufferSize: 1024 * 1024, + MaxConns: maxConns, + MaxPendingRequests: parallelism * maxConns, } - clientID := uint32(0) requestURI := "/foo/bar?baz=123" url := "http://unused.host" + requestURI b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { - n := atomic.AddUint32(&clientID, 1) - c := clients[n%uint32(len(clients))] var req Request req.SetRequestURI(url) var resp Response diff -Nru golang-github-valyala-fasthttp-20160617/client_timing_wait_test.go golang-github-valyala-fasthttp-1.31.0/client_timing_wait_test.go --- golang-github-valyala-fasthttp-20160617/client_timing_wait_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/client_timing_wait_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,167 @@ +//go:build go1.11 +// +build go1.11 + +package fasthttp + +import ( + "io/ioutil" + "net" + "net/http" + "strings" + "testing" + "time" + + "github.com/valyala/fasthttp/fasthttputil" +) + +func newFasthttpSleepEchoHandler(sleep time.Duration) RequestHandler { + return func(ctx *RequestCtx) { + time.Sleep(sleep) + ctx.Success("text/plain", ctx.RequestURI()) + } +} + +func BenchmarkClientGetEndToEndWaitConn1Inmemory(b *testing.B) { + benchmarkClientGetEndToEndWaitConnInmemory(b, 1) +} + +func BenchmarkClientGetEndToEndWaitConn10Inmemory(b *testing.B) { + benchmarkClientGetEndToEndWaitConnInmemory(b, 10) +} + +func BenchmarkClientGetEndToEndWaitConn100Inmemory(b *testing.B) { + benchmarkClientGetEndToEndWaitConnInmemory(b, 100) +} + +func BenchmarkClientGetEndToEndWaitConn1000Inmemory(b *testing.B) { + benchmarkClientGetEndToEndWaitConnInmemory(b, 1000) +} + +func benchmarkClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) { + ln := fasthttputil.NewInmemoryListener() + + ch := make(chan struct{}) + sleepDuration := 50 * time.Millisecond + go func() { + + if err := Serve(ln, newFasthttpSleepEchoHandler(sleepDuration)); err != nil { + b.Errorf("error when serving requests: %s", err) + } + close(ch) + }() + + c := &Client{ + MaxConnsPerHost: 1, + Dial: func(addr string) 
(net.Conn, error) { return ln.Dial() }, + MaxConnWaitTimeout: 5 * time.Second, + } + + requestURI := "/foo/bar?baz=123&sleep=10ms" + url := "http://unused.host" + requestURI + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + var buf []byte + for pb.Next() { + statusCode, body, err := c.Get(buf, url) + if err != nil { + if err != ErrNoFreeConns { + b.Fatalf("unexpected error: %s", err) + } + } else { + if statusCode != StatusOK { + b.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusOK) + } + if string(body) != requestURI { + b.Fatalf("unexpected response %q. Expecting %q", body, requestURI) + } + } + buf = body + } + }) + + ln.Close() + select { + case <-ch: + case <-time.After(time.Second): + b.Fatalf("server wasn't stopped") + } +} + +func newNethttpSleepEchoHandler(sleep time.Duration) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + time.Sleep(sleep) + w.Header().Set(HeaderContentType, "text/plain") + w.Write([]byte(r.RequestURI)) //nolint:errcheck + } +} + +func BenchmarkNetHTTPClientGetEndToEndWaitConn1Inmemory(b *testing.B) { + benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 1) +} + +func BenchmarkNetHTTPClientGetEndToEndWaitConn10Inmemory(b *testing.B) { + benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 10) +} + +func BenchmarkNetHTTPClientGetEndToEndWaitConn100Inmemory(b *testing.B) { + benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 100) +} + +func BenchmarkNetHTTPClientGetEndToEndWaitConn1000Inmemory(b *testing.B) { + benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b, 1000) +} + +func benchmarkNetHTTPClientGetEndToEndWaitConnInmemory(b *testing.B, parallelism int) { + ln := fasthttputil.NewInmemoryListener() + + ch := make(chan struct{}) + sleep := 50 * time.Millisecond + go func() { + if err := http.Serve(ln, newNethttpSleepEchoHandler(sleep)); err != nil && !strings.Contains( + err.Error(), "use of closed network connection") { + b.Errorf("error when serving requests: %s", err) + } + close(ch) + }() + + c := &http.Client{ + Transport: &http.Transport{ + Dial: func(_, _ string) (net.Conn, error) { return ln.Dial() }, + MaxConnsPerHost: 1, + }, + Timeout: 5 * time.Second, + } + + requestURI := "/foo/bar?baz=123" + url := "http://unused.host" + requestURI + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + resp, err := c.Get(url) + if err != nil { + if netErr, ok := err.(net.Error); !ok || !netErr.Timeout() { + b.Fatalf("unexpected error: %s", err) + } + } else { + if resp.StatusCode != http.StatusOK { + b.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode, http.StatusOK) + } + body, err := ioutil.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { + b.Fatalf("unexpected error when reading response body: %s", err) + } + if string(body) != requestURI { + b.Fatalf("unexpected response %q. Expecting %q", body, requestURI) + } + } + } + }) + + ln.Close() + select { + case <-ch: + case <-time.After(time.Second): + b.Fatalf("server wasn't stopped") + } +} diff -Nru golang-github-valyala-fasthttp-20160617/coarseTime.go golang-github-valyala-fasthttp-1.31.0/coarseTime.go --- golang-github-valyala-fasthttp-20160617/coarseTime.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/coarseTime.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,13 @@ +package fasthttp + +import ( + "time" +) + +// CoarseTimeNow returns the current time truncated to the nearest second. 
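// [Editor's illustrative aside - not part of the upstream patch.] The
// deprecation note that follows says CoarseTimeNow is now only a shortcut for
// time.Now().Truncate(time.Second); a minimal sketch of the equivalent call,
// assuming only the exported fasthttp API:
package main

import (
	"fmt"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	coarse := fasthttp.CoarseTimeNow()         // deprecated second-resolution helper
	direct := time.Now().Truncate(time.Second) // what the deprecation note recommends instead
	fmt.Println(coarse, direct)                // both values carry at most second precision
}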
+// +// Deprecated: This is slower than calling time.Now() directly. +// This is now time.Now().Truncate(time.Second) shortcut. +func CoarseTimeNow() time.Time { + return time.Now().Truncate(time.Second) +} diff -Nru golang-github-valyala-fasthttp-20160617/coarseTime_test.go golang-github-valyala-fasthttp-1.31.0/coarseTime_test.go --- golang-github-valyala-fasthttp-20160617/coarseTime_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/coarseTime_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,37 @@ +package fasthttp + +import ( + "sync/atomic" + "testing" + "time" +) + +func BenchmarkCoarseTimeNow(b *testing.B) { + var zeroTimeCount uint64 + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + t := CoarseTimeNow() + if t.IsZero() { + atomic.AddUint64(&zeroTimeCount, 1) + } + } + }) + if zeroTimeCount > 0 { + b.Fatalf("zeroTimeCount must be zero") + } +} + +func BenchmarkTimeNow(b *testing.B) { + var zeroTimeCount uint64 + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + t := time.Now() + if t.IsZero() { + atomic.AddUint64(&zeroTimeCount, 1) + } + } + }) + if zeroTimeCount > 0 { + b.Fatalf("zeroTimeCount must be zero") + } +} diff -Nru golang-github-valyala-fasthttp-20160617/compress.go golang-github-valyala-fasthttp-1.31.0/compress.go --- golang-github-valyala-fasthttp-20160617/compress.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/compress.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,6 +1,7 @@ package fasthttp import ( + "bytes" "fmt" "io" "os" @@ -9,6 +10,8 @@ "github.com/klauspost/compress/flate" "github.com/klauspost/compress/gzip" "github.com/klauspost/compress/zlib" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/stackless" ) // Supported compression levels. @@ -16,7 +19,8 @@ CompressNoCompression = flate.NoCompression CompressBestSpeed = flate.BestSpeed CompressBestCompression = flate.BestCompression - CompressDefaultCompression = flate.DefaultCompression + CompressDefaultCompression = 6 // flate.DefaultCompression + CompressHuffmanOnly = -2 // flate.HuffmanOnly ) func acquireGzipReader(r io.Reader) (*gzip.Reader, error) { @@ -69,48 +73,54 @@ var flateReaderPool sync.Pool -func acquireGzipWriter(w io.Writer, level int) *gzipWriter { - p := gzipWriterPoolMap[level] - if p == nil { - panic(fmt.Sprintf("BUG: unexpected compression level passed: %d. 
See compress/gzip for supported levels", level)) +func acquireStacklessGzipWriter(w io.Writer, level int) stackless.Writer { + nLevel := normalizeCompressLevel(level) + p := stacklessGzipWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealGzipWriter(w, level) + }) } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} +func releaseStacklessGzipWriter(sw stackless.Writer, level int) { + sw.Close() + nLevel := normalizeCompressLevel(level) + p := stacklessGzipWriterPoolMap[nLevel] + p.Put(sw) +} + +func acquireRealGzipWriter(w io.Writer, level int) *gzip.Writer { + nLevel := normalizeCompressLevel(level) + p := realGzipWriterPoolMap[nLevel] v := p.Get() if v == nil { zw, err := gzip.NewWriterLevel(w, level) if err != nil { panic(fmt.Sprintf("BUG: unexpected error from gzip.NewWriterLevel(%d): %s", level, err)) } - return &gzipWriter{ - Writer: zw, - p: p, - } + return zw } - zw := v.(*gzipWriter) + zw := v.(*gzip.Writer) zw.Reset(w) return zw } -func releaseGzipWriter(zw *gzipWriter) { +func releaseRealGzipWriter(zw *gzip.Writer, level int) { zw.Close() - zw.p.Put(zw) + nLevel := normalizeCompressLevel(level) + p := realGzipWriterPoolMap[nLevel] + p.Put(zw) } -type gzipWriter struct { - *gzip.Writer - p *sync.Pool -} - -var gzipWriterPoolMap = func() map[int]*sync.Pool { - // Initialize pools for all the compression levels defined - // in https://golang.org/pkg/compress/gzip/#pkg-constants . - m := make(map[int]*sync.Pool, 11) - m[-1] = &sync.Pool{} - for i := 0; i < 10; i++ { - m[i] = &sync.Pool{} - } - return m -}() +var ( + stacklessGzipWriterPoolMap = newCompressWriterPoolMap() + realGzipWriterPoolMap = newCompressWriterPoolMap() +) // AppendGzipBytesLevel appends gzipped src to dst using the given // compression level and returns the resulting dst. @@ -121,9 +131,10 @@ // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression +// * CompressHuffmanOnly func AppendGzipBytesLevel(dst, src []byte, level int) []byte { w := &byteSliceWriter{dst} - WriteGzipLevel(w, src, level) + WriteGzipLevel(w, src, level) //nolint:errcheck return w.b } @@ -136,11 +147,40 @@ // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression +// * CompressHuffmanOnly func WriteGzipLevel(w io.Writer, p []byte, level int) (int, error) { - zw := acquireGzipWriter(w, level) - n, err := zw.Write(p) - releaseGzipWriter(zw) - return n, err + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + // These writers don't block, so we can just use stacklessWriteGzip + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteGzip(ctx) + return len(p), nil + default: + zw := acquireStacklessGzipWriter(w, level) + n, err := zw.Write(p) + releaseStacklessGzipWriter(zw, level) + return n, err + } +} + +var stacklessWriteGzip = stackless.NewFunc(nonblockingWriteGzip) + +func nonblockingWriteGzip(ctxv interface{}) { + ctx := ctxv.(*compressCtx) + zw := acquireRealGzipWriter(ctx.w, ctx.level) + + _, err := zw.Write(ctx.p) + if err != nil { + panic(fmt.Sprintf("BUG: gzip.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err)) + } + + releaseRealGzipWriter(zw, ctx.level) } // WriteGzip writes gzipped p to w and returns the number of compressed @@ -171,6 +211,91 @@ return nn, err } +// AppendGunzipBytes appends gunzipped src to dst and returns the resulting dst. 
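// [Editor's illustrative aside - not part of the upstream patch.] A minimal
// round-trip sketch of the append-style helpers touched by this hunk
// (AppendGzipBytesLevel and AppendGunzipBytes), using only the exported API
// declared in this file:
package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	src := []byte("hello, fasthttp compression")

	// Compress into a fresh buffer at one of the supported levels.
	gz := fasthttp.AppendGzipBytesLevel(nil, src, fasthttp.CompressBestSpeed)

	// Decompress; AppendGunzipBytes appends to dst and reports any gzip error.
	plain, err := fasthttp.AppendGunzipBytes(nil, gz)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%s\n", plain) // hello, fasthttp compression
}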
+func AppendGunzipBytes(dst, src []byte) ([]byte, error) { + w := &byteSliceWriter{dst} + _, err := WriteGunzip(w, src) + return w.b, err +} + +// AppendDeflateBytesLevel appends deflated src to dst using the given +// compression level and returns the resulting dst. +// +// Supported compression levels are: +// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression +// * CompressHuffmanOnly +func AppendDeflateBytesLevel(dst, src []byte, level int) []byte { + w := &byteSliceWriter{dst} + WriteDeflateLevel(w, src, level) //nolint:errcheck + return w.b +} + +// WriteDeflateLevel writes deflated p to w using the given compression level +// and returns the number of compressed bytes written to w. +// +// Supported compression levels are: +// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression +// * CompressHuffmanOnly +func WriteDeflateLevel(w io.Writer, p []byte, level int) (int, error) { + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + // These writers don't block, so we can just use stacklessWriteDeflate + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteDeflate(ctx) + return len(p), nil + default: + zw := acquireStacklessDeflateWriter(w, level) + n, err := zw.Write(p) + releaseStacklessDeflateWriter(zw, level) + return n, err + } +} + +var stacklessWriteDeflate = stackless.NewFunc(nonblockingWriteDeflate) + +func nonblockingWriteDeflate(ctxv interface{}) { + ctx := ctxv.(*compressCtx) + zw := acquireRealDeflateWriter(ctx.w, ctx.level) + + _, err := zw.Write(ctx.p) + if err != nil { + panic(fmt.Sprintf("BUG: zlib.Writer.Write for len(p)=%d returned unexpected error: %s", len(ctx.p), err)) + } + + releaseRealDeflateWriter(zw, ctx.level) +} + +type compressCtx struct { + w io.Writer + p []byte + level int +} + +// WriteDeflate writes deflated p to w and returns the number of compressed +// bytes written to w. +func WriteDeflate(w io.Writer, p []byte) (int, error) { + return WriteDeflateLevel(w, p, CompressDefaultCompression) +} + +// AppendDeflateBytes appends deflated src to dst and returns the resulting dst. +func AppendDeflateBytes(dst, src []byte) []byte { + return AppendDeflateBytesLevel(dst, src, CompressDefaultCompression) +} + // WriteInflate writes inflated p to w and returns the number of uncompressed // bytes written to w. func WriteInflate(w io.Writer, p []byte) (int, error) { @@ -188,10 +313,10 @@ return nn, err } -// AppendGunzipBytes append gunzipped src to dst and returns the resulting dst. -func AppendGunzipBytes(dst, src []byte) ([]byte, error) { +// AppendInflateBytes appends inflated src to dst and returns the resulting dst. +func AppendInflateBytes(dst, src []byte) ([]byte, error) { w := &byteSliceWriter{dst} - _, err := WriteGunzip(w, src) + _, err := WriteInflate(w, src) return w.b, err } @@ -217,68 +342,106 @@ return n, nil } -func acquireFlateWriter(w io.Writer, level int) *flateWriter { - p := flateWriterPoolMap[level] - if p == nil { - panic(fmt.Sprintf("BUG: unexpected compression level passed: %d. 
See compress/flate for supported levels", level)) +func (r *byteSliceReader) ReadByte() (byte, error) { + if len(r.b) == 0 { + return 0, io.EOF } + n := r.b[0] + r.b = r.b[1:] + return n, nil +} +func acquireStacklessDeflateWriter(w io.Writer, level int) stackless.Writer { + nLevel := normalizeCompressLevel(level) + p := stacklessDeflateWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealDeflateWriter(w, level) + }) + } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} + +func releaseStacklessDeflateWriter(sw stackless.Writer, level int) { + sw.Close() + nLevel := normalizeCompressLevel(level) + p := stacklessDeflateWriterPoolMap[nLevel] + p.Put(sw) +} + +func acquireRealDeflateWriter(w io.Writer, level int) *zlib.Writer { + nLevel := normalizeCompressLevel(level) + p := realDeflateWriterPoolMap[nLevel] v := p.Get() if v == nil { zw, err := zlib.NewWriterLevel(w, level) if err != nil { - panic(fmt.Sprintf("BUG: unexpected error in zlib.NewWriterLevel(%d): %s", level, err)) - } - return &flateWriter{ - Writer: zw, - p: p, + panic(fmt.Sprintf("BUG: unexpected error from zlib.NewWriterLevel(%d): %s", level, err)) } + return zw } - zw := v.(*flateWriter) + zw := v.(*zlib.Writer) zw.Reset(w) return zw } -func releaseFlateWriter(zw *flateWriter) { +func releaseRealDeflateWriter(zw *zlib.Writer, level int) { zw.Close() - zw.p.Put(zw) + nLevel := normalizeCompressLevel(level) + p := realDeflateWriterPoolMap[nLevel] + p.Put(zw) } -type flateWriter struct { - *zlib.Writer - p *sync.Pool -} +var ( + stacklessDeflateWriterPoolMap = newCompressWriterPoolMap() + realDeflateWriterPoolMap = newCompressWriterPoolMap() +) -var flateWriterPoolMap = func() map[int]*sync.Pool { +func newCompressWriterPoolMap() []*sync.Pool { // Initialize pools for all the compression levels defined // in https://golang.org/pkg/compress/flate/#pkg-constants . - m := make(map[int]*sync.Pool, 11) - m[-1] = &sync.Pool{} - for i := 0; i < 10; i++ { - m[i] = &sync.Pool{} + // Compression levels are normalized with normalizeCompressLevel, + // so the fit [0..11]. + var m []*sync.Pool + for i := 0; i < 12; i++ { + m = append(m, &sync.Pool{}) } return m -}() +} func isFileCompressible(f *os.File, minCompressRatio float64) bool { // Try compressing the first 4kb of of the file // and see if it can be compressed by more than // the given minCompressRatio. - b := AcquireByteBuffer() - zw := acquireGzipWriter(b, CompressDefaultCompression) + b := bytebufferpool.Get() + zw := acquireStacklessGzipWriter(b, CompressDefaultCompression) lr := &io.LimitedReader{ R: f, N: 4096, } _, err := copyZeroAlloc(zw, lr) - releaseGzipWriter(zw) - f.Seek(0, 0) + releaseStacklessGzipWriter(zw, CompressDefaultCompression) + f.Seek(0, 0) //nolint:errcheck if err != nil { return false } n := 4096 - lr.N zn := len(b.B) - ReleaseByteBuffer(b) + bytebufferpool.Put(b) return float64(zn) < float64(n)*minCompressRatio } + +// normalizes compression level into [0..11], so it could be used as an index +// in *PoolMap. 
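// [Editor's illustrative aside - not part of the upstream patch.] A sketch of
// the level-to-index mapping documented here; normalize is a local stand-in
// for the unexported normalizeCompressLevel defined just below, restated only
// so the example compiles on its own.
package main

import "fmt"

func normalize(level int) int {
	// Levels outside [-2..9] fall back to CompressDefaultCompression (6),
	// then the +2 shift turns [-2..9] into the pool indexes [0..11].
	if level < -2 || level > 9 {
		level = 6
	}
	return level + 2
}

func main() {
	for _, level := range []int{-2, -1, 0, 6, 9, 42} {
		fmt.Printf("level %3d -> pool index %2d\n", level, normalize(level))
	}
}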
+func normalizeCompressLevel(level int) int { + // -2 is the lowest compression level - CompressHuffmanOnly + // 9 is the highest compression level - CompressBestCompression + if level < -2 || level > 9 { + level = CompressDefaultCompression + } + return level + 2 +} diff -Nru golang-github-valyala-fasthttp-20160617/compress_test.go golang-github-valyala-fasthttp-1.31.0/compress_test.go --- golang-github-valyala-fasthttp-20160617/compress_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/compress_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -2,88 +2,231 @@ import ( "bytes" + "fmt" "io/ioutil" "testing" + "time" ) -func TestGzipBytes(t *testing.T) { - testGzipBytes(t, "") - testGzipBytes(t, "foobar") - testGzipBytes(t, "выфаодлодл одлфываыв sd2 k34") +var compressTestcases = func() []string { + a := []string{ + "", + "foobar", + "выфаодлодл одлфываыв sd2 k34", + } + bigS := createFixedBody(1e4) + a = append(a, string(bigS)) + return a +}() + +func TestGzipBytesSerial(t *testing.T) { + t.Parallel() + + if err := testGzipBytes(); err != nil { + t.Fatal(err) + } +} + +func TestGzipBytesConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testGzipBytes); err != nil { + t.Fatal(err) + } +} + +func TestDeflateBytesSerial(t *testing.T) { + t.Parallel() + + if err := testDeflateBytes(); err != nil { + t.Fatal(err) + } +} + +func TestDeflateBytesConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testDeflateBytes); err != nil { + t.Fatal(err) + } } -func testGzipBytes(t *testing.T, s string) { +func testGzipBytes() error { + for _, s := range compressTestcases { + if err := testGzipBytesSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testDeflateBytes() error { + for _, s := range compressTestcases { + if err := testDeflateBytesSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testGzipBytesSingleCase(s string) error { prefix := []byte("foobar") gzippedS := AppendGzipBytes(prefix, []byte(s)) if !bytes.Equal(gzippedS[:len(prefix)], prefix) { - t.Fatalf("unexpected prefix when compressing %q: %q. Expecting %q", s, gzippedS[:len(prefix)], prefix) + return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, gzippedS[:len(prefix)], prefix) } gunzippedS, err := AppendGunzipBytes(prefix, gzippedS[len(prefix):]) if err != nil { - t.Fatalf("unexpected error when uncompressing %q: %s", s, err) + return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err) } if !bytes.Equal(gunzippedS[:len(prefix)], prefix) { - t.Fatalf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, gunzippedS[:len(prefix)], prefix) + return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, gunzippedS[:len(prefix)], prefix) } gunzippedS = gunzippedS[len(prefix):] if string(gunzippedS) != s { - t.Fatalf("unexpected uncompressed string %q. Expecting %q", gunzippedS, s) + return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", gunzippedS, s) + } + return nil +} + +func testDeflateBytesSingleCase(s string) error { + prefix := []byte("foobar") + deflatedS := AppendDeflateBytes(prefix, []byte(s)) + if !bytes.Equal(deflatedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when compressing %q: %q. 
Expecting %q", s, deflatedS[:len(prefix)], prefix) + } + + inflatedS, err := AppendInflateBytes(prefix, deflatedS[len(prefix):]) + if err != nil { + return fmt.Errorf("unexpected error when uncompressing %q: %s", s, err) + } + if !bytes.Equal(inflatedS[:len(prefix)], prefix) { + return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, inflatedS[:len(prefix)], prefix) + } + inflatedS = inflatedS[len(prefix):] + if string(inflatedS) != s { + return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", inflatedS, s) + } + return nil +} + +func TestGzipCompressSerial(t *testing.T) { + t.Parallel() + + if err := testGzipCompress(); err != nil { + t.Fatal(err) + } +} + +func TestGzipCompressConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testGzipCompress); err != nil { + t.Fatal(err) } } -func TestGzipCompress(t *testing.T) { - testGzipCompress(t, "") - testGzipCompress(t, "foobar") - testGzipCompress(t, "ajjnkn asdlkjfqoijfw jfqkwj foj eowjiq") +func TestFlateCompressSerial(t *testing.T) { + t.Parallel() + + if err := testFlateCompress(); err != nil { + t.Fatal(err) + } } -func TestFlateCompress(t *testing.T) { - testFlateCompress(t, "") - testFlateCompress(t, "foobar") - testFlateCompress(t, "adf asd asd fasd fasd") +func TestFlateCompressConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(10, testFlateCompress); err != nil { + t.Fatal(err) + } +} + +func testGzipCompress() error { + for _, s := range compressTestcases { + if err := testGzipCompressSingleCase(s); err != nil { + return err + } + } + return nil } -func testGzipCompress(t *testing.T, s string) { +func testFlateCompress() error { + for _, s := range compressTestcases { + if err := testFlateCompressSingleCase(s); err != nil { + return err + } + } + return nil +} + +func testGzipCompressSingleCase(s string) error { var buf bytes.Buffer - zw := acquireGzipWriter(&buf, CompressDefaultCompression) + zw := acquireStacklessGzipWriter(&buf, CompressDefaultCompression) if _, err := zw.Write([]byte(s)); err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. s=%q", err, s) } - releaseGzipWriter(zw) + releaseStacklessGzipWriter(zw, CompressDefaultCompression) zr, err := acquireGzipReader(&buf) if err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. s=%q", err, s) } body, err := ioutil.ReadAll(zr) if err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. s=%q", err, s) } if string(body) != s { - t.Fatalf("unexpected string after decompression: %q. Expecting %q", body, s) + return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s) } releaseGzipReader(zr) + return nil } -func testFlateCompress(t *testing.T, s string) { +func testFlateCompressSingleCase(s string) error { var buf bytes.Buffer - zw := acquireFlateWriter(&buf, CompressDefaultCompression) + zw := acquireStacklessDeflateWriter(&buf, CompressDefaultCompression) if _, err := zw.Write([]byte(s)); err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. s=%q", err, s) } - releaseFlateWriter(zw) + releaseStacklessDeflateWriter(zw, CompressDefaultCompression) zr, err := acquireFlateReader(&buf) if err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. 
s=%q", err, s) } body, err := ioutil.ReadAll(zr) if err != nil { - t.Fatalf("unexpected error: %s. s=%q", err, s) + return fmt.Errorf("unexpected error: %s. s=%q", err, s) } if string(body) != s { - t.Fatalf("unexpected string after decompression: %q. Expecting %q", body, s) + return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s) } releaseFlateReader(zr) + return nil +} + +func testConcurrent(concurrency int, f func() error) error { + ch := make(chan error, concurrency) + for i := 0; i < concurrency; i++ { + go func(idx int) { + err := f() + if err != nil { + ch <- fmt.Errorf("error in goroutine %d: %s", idx, err) + } + ch <- nil + }(i) + } + for i := 0; i < concurrency; i++ { + select { + case err := <-ch: + if err != nil { + return err + } + case <-time.After(time.Second): + return fmt.Errorf("timeout") + } + } + return nil } diff -Nru golang-github-valyala-fasthttp-20160617/cookie.go golang-github-valyala-fasthttp-1.31.0/cookie.go --- golang-github-valyala-fasthttp-20160617/cookie.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/cookie.go 2021-10-09 18:39:05.000000000 +0000 @@ -18,6 +18,24 @@ CookieExpireUnlimited = zeroTime ) +// CookieSameSite is an enum for the mode in which the SameSite flag should be set for the given cookie. +// See https://tools.ietf.org/html/draft-ietf-httpbis-cookie-same-site-00 for details. +type CookieSameSite int + +const ( + // CookieSameSiteDisabled removes the SameSite flag + CookieSameSiteDisabled CookieSameSite = iota + // CookieSameSiteDefaultMode sets the SameSite flag + CookieSameSiteDefaultMode + // CookieSameSiteLaxMode sets the SameSite flag with the "Lax" parameter + CookieSameSiteLaxMode + // CookieSameSiteStrictMode sets the SameSite flag with the "Strict" parameter + CookieSameSiteStrictMode + // CookieSameSiteNoneMode sets the SameSite flag with the "None" parameter + // see https://tools.ietf.org/html/draft-west-cookie-incrementalism-00 + CookieSameSiteNoneMode +) + // AcquireCookie returns an empty Cookie object from the pool. // // The returned object may be returned back to the pool with ReleaseCookie. @@ -47,16 +65,18 @@ // // Cookie instance MUST NOT be used from concurrently running goroutines. type Cookie struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck key []byte value []byte expire time.Time + maxAge int domain []byte path []byte httpOnly bool secure bool + sameSite CookieSameSite bufKV argsKV buf []byte @@ -65,13 +85,15 @@ // CopyTo copies src cookie to c. func (c *Cookie) CopyTo(src *Cookie) { c.Reset() - c.key = append(c.key[:0], src.key...) - c.value = append(c.value[:0], src.value...) + c.key = append(c.key, src.key...) + c.value = append(c.value, src.value...) c.expire = src.expire - c.domain = append(c.domain[:0], src.domain...) - c.path = append(c.path[:0], src.path...) + c.maxAge = src.maxAge + c.domain = append(c.domain, src.domain...) + c.path = append(c.path, src.path...) c.httpOnly = src.httpOnly c.secure = src.secure + c.sameSite = src.sameSite } // HTTPOnly returns true if the cookie is http only. @@ -94,6 +116,20 @@ c.secure = secure } +// SameSite returns the SameSite mode. +func (c *Cookie) SameSite() CookieSameSite { + return c.sameSite +} + +// SetSameSite sets the cookie's SameSite flag to the given value. 
+// set value CookieSameSiteNoneMode will set Secure to true also to avoid browser rejection +func (c *Cookie) SetSameSite(mode CookieSameSite) { + c.sameSite = mode + if mode == CookieSameSiteNoneMode { + c.SetSecure(true) + } +} + // Path returns cookie path. func (c *Cookie) Path() []byte { return c.path @@ -113,7 +149,8 @@ // Domain returns cookie domain. // -// The returned domain is valid until the next Cookie modification method call. +// The returned value is valid until the Cookie reused or released (ReleaseCookie). +// Do not store references to the returned value. Make copies instead. func (c *Cookie) Domain() []byte { return c.domain } @@ -128,6 +165,20 @@ c.domain = append(c.domain[:0], domain...) } +// MaxAge returns the seconds until the cookie is meant to expire or 0 +// if no max age. +func (c *Cookie) MaxAge() int { + return c.maxAge +} + +// SetMaxAge sets cookie expiration time based on seconds. This takes precedence +// over any absolute expiry set on the cookie +// +// Set max age to 0 to unset +func (c *Cookie) SetMaxAge(seconds int) { + c.maxAge = seconds +} + // Expire returns cookie expiration time. // // CookieExpireUnlimited is returned if cookie doesn't expire @@ -151,7 +202,8 @@ // Value returns cookie value. // -// The returned value is valid until the next Cookie modification method call. +// The returned value is valid until the Cookie reused or released (ReleaseCookie). +// Do not store references to the returned value. Make copies instead. func (c *Cookie) Value() []byte { return c.value } @@ -168,7 +220,8 @@ // Key returns cookie name. // -// The returned value is valid until the next Cookie modification method call. +// The returned value is valid until the Cookie reused or released (ReleaseCookie). +// Do not store references to the returned value. Make copies instead. func (c *Cookie) Key() []byte { return c.key } @@ -188,22 +241,29 @@ c.key = c.key[:0] c.value = c.value[:0] c.expire = zeroTime + c.maxAge = 0 c.domain = c.domain[:0] c.path = c.path[:0] c.httpOnly = false c.secure = false + c.sameSite = CookieSameSiteDisabled } // AppendBytes appends cookie representation to dst and returns // the extended dst. func (c *Cookie) AppendBytes(dst []byte) []byte { if len(c.key) > 0 { - dst = AppendQuotedArg(dst, c.key) + dst = append(dst, c.key...) dst = append(dst, '=') } - dst = AppendQuotedArg(dst, c.value) + dst = append(dst, c.value...) - if !c.expire.IsZero() { + if c.maxAge > 0 { + dst = append(dst, ';', ' ') + dst = append(dst, strCookieMaxAge...) + dst = append(dst, '=') + dst = AppendUint(dst, c.maxAge) + } else if !c.expire.IsZero() { c.bufKV.value = AppendHTTPDate(c.bufKV.value[:0], c.expire) dst = append(dst, ';', ' ') dst = append(dst, strCookieExpires...) @@ -224,12 +284,33 @@ dst = append(dst, ';', ' ') dst = append(dst, strCookieSecure...) } + switch c.sameSite { + case CookieSameSiteDefaultMode: + dst = append(dst, ';', ' ') + dst = append(dst, strCookieSameSite...) + case CookieSameSiteLaxMode: + dst = append(dst, ';', ' ') + dst = append(dst, strCookieSameSite...) + dst = append(dst, '=') + dst = append(dst, strCookieSameSiteLax...) + case CookieSameSiteStrictMode: + dst = append(dst, ';', ' ') + dst = append(dst, strCookieSameSite...) + dst = append(dst, '=') + dst = append(dst, strCookieSameSiteStrict...) + case CookieSameSiteNoneMode: + dst = append(dst, ';', ' ') + dst = append(dst, strCookieSameSite...) + dst = append(dst, '=') + dst = append(dst, strCookieSameSiteNone...) + } return dst } // Cookie returns cookie representation. 
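// [Editor's illustrative aside - not part of the upstream patch.] A minimal
// sketch of the new SameSite and max-age behavior described in the doc
// comments above, using only exported Cookie methods:
package main

import (
	"fmt"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	c := fasthttp.AcquireCookie()
	defer fasthttp.ReleaseCookie(c)

	c.SetKey("session")
	c.SetValue("abc")

	// SameSite=None forces the Secure flag, as documented for SetSameSite.
	c.SetSameSite(fasthttp.CookieSameSiteNoneMode)
	fmt.Println(c.Secure()) // true

	// A non-zero max-age takes precedence over an absolute expiry.
	c.SetExpire(time.Now().Add(time.Hour))
	c.SetMaxAge(100)
	fmt.Println(string(c.Cookie())) // session=abc; max-age=100; secure; SameSite=None
}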
// -// The returned value is valid until the next call to Cookie methods. +// The returned value is valid until the Cookie reused or released (ReleaseCookie). +// Do not store references to the returned value. Make copies instead. func (c *Cookie) Cookie() []byte { c.buf = c.AppendBytes(c.buf[:0]) return c.buf @@ -264,37 +345,89 @@ s.b = src kv := &c.bufKV - if !s.next(kv, true) { + if !s.next(kv) { return errNoCookies } - c.key = append(c.key[:0], kv.key...) - c.value = append(c.value[:0], kv.value...) + c.key = append(c.key, kv.key...) + c.value = append(c.value, kv.value...) - for s.next(kv, false) { - if len(kv.key) == 0 && len(kv.value) == 0 { - continue - } - switch string(kv.key) { - case "expires": - v := b2s(kv.value) - exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) - if err != nil { - return err + for s.next(kv) { + if len(kv.key) != 0 { + // Case insensitive switch on first char + switch kv.key[0] | 0x20 { + case 'm': + if caseInsensitiveCompare(strCookieMaxAge, kv.key) { + maxAge, err := ParseUint(kv.value) + if err != nil { + return err + } + c.maxAge = maxAge + } + + case 'e': // "expires" + if caseInsensitiveCompare(strCookieExpires, kv.key) { + v := b2s(kv.value) + // Try the same two formats as net/http + // See: https://github.com/golang/go/blob/00379be17e63a5b75b3237819392d2dc3b313a27/src/net/http/cookie.go#L133-L135 + exptime, err := time.ParseInLocation(time.RFC1123, v, time.UTC) + if err != nil { + exptime, err = time.Parse("Mon, 02-Jan-2006 15:04:05 MST", v) + if err != nil { + return err + } + } + c.expire = exptime + } + + case 'd': // "domain" + if caseInsensitiveCompare(strCookieDomain, kv.key) { + c.domain = append(c.domain, kv.value...) + } + + case 'p': // "path" + if caseInsensitiveCompare(strCookiePath, kv.key) { + c.path = append(c.path, kv.value...) + } + + case 's': // "samesite" + if caseInsensitiveCompare(strCookieSameSite, kv.key) { + if len(kv.value) > 0 { + // Case insensitive switch on first char + switch kv.value[0] | 0x20 { + case 'l': // "lax" + if caseInsensitiveCompare(strCookieSameSiteLax, kv.value) { + c.sameSite = CookieSameSiteLaxMode + } + case 's': // "strict" + if caseInsensitiveCompare(strCookieSameSiteStrict, kv.value) { + c.sameSite = CookieSameSiteStrictMode + } + case 'n': // "none" + if caseInsensitiveCompare(strCookieSameSiteNone, kv.value) { + c.sameSite = CookieSameSiteNoneMode + } + } + } + } } - c.expire = exptime - case "domain": - c.domain = append(c.domain[:0], kv.value...) - case "path": - c.path = append(c.path[:0], kv.value...) - case "": - switch string(kv.value) { - case "HttpOnly": - c.httpOnly = true - case "secure": - c.secure = true + + } else if len(kv.value) != 0 { + // Case insensitive switch on first char + switch kv.value[0] | 0x20 { + case 'h': // "httponly" + if caseInsensitiveCompare(strCookieHTTPOnly, kv.value) { + c.httpOnly = true + } + + case 's': // "secure" + if caseInsensitiveCompare(strCookieSecure, kv.value) { + c.secure = true + } else if caseInsensitiveCompare(strCookieSameSite, kv.value) { + c.sameSite = CookieSameSiteDefaultMode + } } - } + } // else empty or no match } return nil } @@ -311,17 +444,30 @@ if n >= 0 { src = src[:n] } - return decodeCookieArg(dst, src, true) + return decodeCookieArg(dst, src, false) } func appendRequestCookieBytes(dst []byte, cookies []argsKV) []byte { for i, n := 0, len(cookies); i < n; i++ { kv := &cookies[i] if len(kv.key) > 0 { - dst = AppendQuotedArg(dst, kv.key) + dst = append(dst, kv.key...) 
dst = append(dst, '=') } - dst = AppendQuotedArg(dst, kv.value) + dst = append(dst, kv.value...) + if i+1 < n { + dst = append(dst, ';', ' ') + } + } + return dst +} + +// For Response we can not use the above function as response cookies +// already contain the key= in the value. +func appendResponseCookieBytes(dst []byte, cookies []argsKV) []byte { + for i, n := 0, len(cookies); i < n; i++ { + kv := &cookies[i] + dst = append(dst, kv.value...) if i+1 < n { dst = append(dst, ';', ' ') } @@ -334,7 +480,7 @@ s.b = src var kv *argsKV cookies, kv = allocArg(cookies) - for s.next(kv, true) { + for s.next(kv) { if len(kv.key) > 0 || len(kv.value) > 0 { cookies, kv = allocArg(cookies) } @@ -346,7 +492,7 @@ b []byte } -func (s *cookieScanner) next(kv *argsKV, decode bool) bool { +func (s *cookieScanner) next(kv *argsKV) bool { b := s.b if len(b) == 0 { return false @@ -359,14 +505,14 @@ case '=': if isKey { isKey = false - kv.key = decodeCookieArg(kv.key, b[:i], decode) + kv.key = decodeCookieArg(kv.key, b[:i], false) k = i + 1 } case ';': if isKey { kv.key = kv.key[:0] } - kv.value = decodeCookieArg(kv.value, b[k:i], decode) + kv.value = decodeCookieArg(kv.value, b[k:i], true) s.b = b[i+1:] return true } @@ -375,20 +521,36 @@ if isKey { kv.key = kv.key[:0] } - kv.value = decodeCookieArg(kv.value, b[k:], decode) + kv.value = decodeCookieArg(kv.value, b[k:], true) s.b = b[len(b):] return true } -func decodeCookieArg(dst, src []byte, decode bool) []byte { +func decodeCookieArg(dst, src []byte, skipQuotes bool) []byte { for len(src) > 0 && src[0] == ' ' { src = src[1:] } for len(src) > 0 && src[len(src)-1] == ' ' { src = src[:len(src)-1] } - if !decode { - return append(dst[:0], src...) + if skipQuotes { + if len(src) > 1 && src[0] == '"' && src[len(src)-1] == '"' { + src = src[1 : len(src)-1] + } + } + return append(dst[:0], src...) +} + +// caseInsensitiveCompare does a case insensitive equality comparison of +// two []byte. Assumes only letters need to be matched. 
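// [Editor's illustrative aside - not part of the upstream patch.] The helper
// below relies on ASCII case folding via the 0x20 bit ('A'|0x20 == 'a');
// asciiEqualFold restates that logic so the trick can be tried standalone.
// Like the original, it assumes only ASCII letters need folding, so some
// non-letter byte pairs (e.g. '[' and '{') would also compare equal.
package main

import "fmt"

func asciiEqualFold(a, b []byte) bool {
	if len(a) != len(b) {
		return false
	}
	for i := 0; i < len(a); i++ {
		if a[i]|0x20 != b[i]|0x20 {
			return false
		}
	}
	return true
}

func main() {
	fmt.Println(asciiEqualFold([]byte("SameSite"), []byte("samesite"))) // true
	fmt.Println(asciiEqualFold([]byte("Max-Age"), []byte("max-age")))   // true
	fmt.Println(asciiEqualFold([]byte("secure"), []byte("Secure1")))    // false (length differs)
}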
+func caseInsensitiveCompare(a, b []byte) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if a[i]|0x20 != b[i]|0x20 { + return false + } } - return decodeArg(dst, src, true) + return true } diff -Nru golang-github-valyala-fasthttp-20160617/cookie_test.go golang-github-valyala-fasthttp-1.31.0/cookie_test.go --- golang-github-valyala-fasthttp-20160617/cookie_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/cookie_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -6,7 +6,52 @@ "time" ) +func TestCookiePanic(t *testing.T) { + t.Parallel() + + var c Cookie + if err := c.Parse(";SAMeSITe="); err != nil { + t.Error(err) + } +} + +func TestCookieValueWithEqualAndSpaceChars(t *testing.T) { + t.Parallel() + + testCookieValueWithEqualAndSpaceChars(t, "sth1", "/", "MTQ2NjU5NTcwN3xfUVduVXk4aG9jSmZaNzNEb1dGa1VjekY1bG9vMmxSWlJBZUN2Q1ZtZVFNMTk2YU9YaWtCVmY1eDRWZXd3M3Q5RTJRZnZMbk5mWklSSFZJcVlXTDhiSFFHWWdpdFVLd1hwbXR2UUN4QlJ1N3BITFpkS3Y4PXzDvPNn6JVDBFB2wYVYPHdkdlZBm6n1_0QB3_GWwE40Tg ==") + testCookieValueWithEqualAndSpaceChars(t, "sth2", "/", "123") + testCookieValueWithEqualAndSpaceChars(t, "sth3", "/", "123 == 1") +} + +func testCookieValueWithEqualAndSpaceChars(t *testing.T, expectedName, expectedPath, expectedValue string) { + var c Cookie + c.SetKey(expectedName) + c.SetPath(expectedPath) + c.SetValue(expectedValue) + + s := c.String() + + var c1 Cookie + if err := c1.Parse(s); err != nil { + t.Fatalf("unexpected error: %s", err) + } + name := c1.Key() + if string(name) != expectedName { + t.Fatalf("unexpected name %q. Expecting %q", name, expectedName) + } + path := c1.Path() + if string(path) != expectedPath { + t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) + } + value := c1.Value() + if string(value) != expectedValue { + t.Fatalf("unexpected value %q. 
Expecting %q", value, expectedValue) + } +} + func TestCookieSecureHttpOnly(t *testing.T) { + t.Parallel() + var c Cookie if err := c.Parse("foo=bar; HttpOnly; secure"); err != nil { @@ -28,6 +73,8 @@ } func TestCookieSecure(t *testing.T) { + t.Parallel() + var c Cookie if err := c.Parse("foo=bar; secure"); err != nil { @@ -44,7 +91,7 @@ if err := c.Parse("foo=bar"); err != nil { t.Fatalf("unexpected error: %s", err) } - if c.HTTPOnly() { + if c.Secure() { t.Fatalf("Unexpected secure flag set") } s = c.String() @@ -53,7 +100,124 @@ } } +func TestCookieSameSite(t *testing.T) { + t.Parallel() + + var c Cookie + + if err := c.Parse("foo=bar; samesite"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if c.SameSite() != CookieSameSiteDefaultMode { + t.Fatalf("SameSite must be set") + } + s := c.String() + if !strings.Contains(s, "; SameSite") { + t.Fatalf("missing SameSite flag in cookie %q", s) + } + + if err := c.Parse("foo=bar; samesite=lax"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if c.SameSite() != CookieSameSiteLaxMode { + t.Fatalf("SameSite Lax Mode must be set") + } + s = c.String() + if !strings.Contains(s, "; SameSite=Lax") { + t.Fatalf("missing SameSite flag in cookie %q", s) + } + + if err := c.Parse("foo=bar; samesite=strict"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if c.SameSite() != CookieSameSiteStrictMode { + t.Fatalf("SameSite Strict Mode must be set") + } + s = c.String() + if !strings.Contains(s, "; SameSite=Strict") { + t.Fatalf("missing SameSite flag in cookie %q", s) + } + + if err := c.Parse("foo=bar; samesite=none"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if c.SameSite() != CookieSameSiteNoneMode { + t.Fatalf("SameSite None Mode must be set") + } + s = c.String() + if !strings.Contains(s, "; SameSite=None") { + t.Fatalf("missing SameSite flag in cookie %q", s) + } + + if err := c.Parse("foo=bar"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + c.SetSameSite(CookieSameSiteNoneMode) + s = c.String() + if !strings.Contains(s, "; SameSite=None") { + t.Fatalf("missing SameSite flag in cookie %q", s) + } + if !strings.Contains(s, "; secure") { + t.Fatalf("missing Secure flag in cookie %q", s) + } + + if err := c.Parse("foo=bar"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if c.SameSite() != CookieSameSiteDisabled { + t.Fatalf("Unexpected SameSite flag set") + } + s = c.String() + if strings.Contains(s, "SameSite") { + t.Fatalf("unexpected SameSite flag in cookie %q", s) + } +} + +func TestCookieMaxAge(t *testing.T) { + t.Parallel() + + var c Cookie + + maxAge := 100 + if err := c.Parse("foo=bar; max-age=100"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if maxAge != c.MaxAge() { + t.Fatalf("max-age must be set") + } + s := c.String() + if !strings.Contains(s, "; max-age=100") { + t.Fatalf("missing max-age flag in cookie %q", s) + } + + if err := c.Parse("foo=bar; expires=Tue, 10 Nov 2009 23:00:00 GMT; max-age=100;"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if maxAge != c.MaxAge() { + t.Fatalf("max-age ignored") + } + s = c.String() + if s != "foo=bar; max-age=100" { + t.Fatalf("missing max-age in cookie %q", s) + } + + expires := time.Unix(100, 0) + c.SetExpire(expires) + s = c.String() + if s != "foo=bar; max-age=100" { + t.Fatalf("expires should be ignored due to max-age: %q", s) + } + + c.SetMaxAge(0) + s = c.String() + if s != "foo=bar; expires=Thu, 01 Jan 1970 00:01:40 GMT" { + t.Fatalf("missing expires %q", s) + } +} + func 
TestCookieHttpOnly(t *testing.T) { + t.Parallel() + var c Cookie if err := c.Parse("foo=bar; HttpOnly"); err != nil { @@ -80,10 +244,14 @@ } func TestCookieAcquireReleaseSequential(t *testing.T) { + t.Parallel() + testCookieAcquireRelease(t) } func TestCookieAcquireReleaseConcurrent(t *testing.T) { + t.Parallel() + ch := make(chan struct{}, 10) for i := 0; i < 10; i++ { go func() { @@ -138,10 +306,15 @@ } func TestCookieParse(t *testing.T) { + t.Parallel() + testCookieParse(t, "foo", "foo") testCookieParse(t, "foo=bar", "foo=bar") testCookieParse(t, "foo=", "foo=") - testCookieParse(t, "foo=bar; domain=aaa.com; path=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") + testCookieParse(t, `foo="bar"`, "foo=bar") + testCookieParse(t, `"foo"=bar`, `"foo"=bar`) + testCookieParse(t, "foo=bar; Domain=aaa.com; PATH=/foo/bar", "foo=bar; domain=aaa.com; path=/foo/bar") + testCookieParse(t, "foo=bar; max-age= 101 ; expires= Tue, 10 Nov 2009 23:00:00 GMT", "foo=bar; max-age=101") testCookieParse(t, " xxx = yyy ; path=/a/b;;;domain=foobar.com ; expires= Tue, 10 Nov 2009 23:00:00 GMT ; ;;", "xxx=yyy; expires=Tue, 10 Nov 2009 23:00:00 GMT; domain=foobar.com; path=/a/b") } @@ -153,16 +326,18 @@ } result := string(c.Cookie()) if result != expectedS { - t.Fatalf("unexpected cookies %q. Expected %q. Original %q", result, expectedS, s) + t.Fatalf("unexpected cookies %q. Expecting %q. Original %q", result, expectedS, s) } } func TestCookieAppendBytes(t *testing.T) { + t.Parallel() + c := &Cookie{} testCookieAppendBytes(t, c, "", "bar", "bar") testCookieAppendBytes(t, c, "foo", "", "foo=") - testCookieAppendBytes(t, c, "ффф", "12 лодлы", "%D1%84%D1%84%D1%84=12%20%D0%BB%D0%BE%D0%B4%D0%BB%D1%8B") + testCookieAppendBytes(t, c, "ффф", "12 лодлы", "ффф=12 лодлы") c.SetDomain("foobar.com") testCookieAppendBytes(t, c, "a", "b", "a=b; domain=foobar.com") @@ -179,11 +354,13 @@ c.SetValue(value) result := string(c.AppendBytes(nil)) if result != expectedS { - t.Fatalf("Unexpected cookie %q. Expected %q", result, expectedS) + t.Fatalf("Unexpected cookie %q. Expecting %q", result, expectedS) } } func TestParseRequestCookies(t *testing.T) { + t.Parallel() + testParseRequestCookies(t, "", "") testParseRequestCookies(t, "=", "") testParseRequestCookies(t, "foo", "foo") @@ -198,20 +375,23 @@ cookies := parseRequestCookies(nil, []byte(s)) ss := string(appendRequestCookieBytes(nil, cookies)) if ss != expectedS { - t.Fatalf("Unexpected cookies after parsing: %q. Expected %q. String to parse %q", ss, expectedS, s) + t.Fatalf("Unexpected cookies after parsing: %q. Expecting %q. String to parse %q", ss, expectedS, s) } } func TestAppendRequestCookieBytes(t *testing.T) { + t.Parallel() + testAppendRequestCookieBytes(t, "=", "") testAppendRequestCookieBytes(t, "foo=", "foo=") testAppendRequestCookieBytes(t, "=bar", "bar") - testAppendRequestCookieBytes(t, "привет=a b;c&s s=aaa", "%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82=a%20b%3Bc; s%20s=aaa") + testAppendRequestCookieBytes(t, "привет=a bc&s s=aaa", "привет=a bc; s s=aaa") } func testAppendRequestCookieBytes(t *testing.T, s, expectedS string) { - var cookies []argsKV - for _, ss := range strings.Split(s, "&") { + kvs := strings.Split(s, "&") + cookies := make([]argsKV, 0, len(kvs)) + for _, ss := range kvs { tmp := strings.SplitN(ss, "=", 2) if len(tmp) != 2 { t.Fatalf("Cannot find '=' in %q, part of %q", ss, s) @@ -225,10 +405,10 @@ prefix := "foobar" result := string(appendRequestCookieBytes([]byte(prefix), cookies)) if result[:len(prefix)] != prefix { - t.Fatalf("unexpected prefix %q. 
Expected %q for cookie %q", result[:len(prefix)], prefix, s) + t.Fatalf("unexpected prefix %q. Expecting %q for cookie %q", result[:len(prefix)], prefix, s) } result = result[len(prefix):] if result != expectedS { - t.Fatalf("Unexpected result %q. Expected %q for cookie %q", result, expectedS, s) + t.Fatalf("Unexpected result %q. Expecting %q for cookie %q", result, expectedS, s) } } diff -Nru golang-github-valyala-fasthttp-20160617/debian/changelog golang-github-valyala-fasthttp-1.31.0/debian/changelog --- golang-github-valyala-fasthttp-20160617/debian/changelog 2021-01-09 11:28:54.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/changelog 2022-03-16 13:54:45.000000000 +0000 @@ -1,9 +1,60 @@ -golang-github-valyala-fasthttp (20160617-2.1) unstable; urgency=medium +golang-github-valyala-fasthttp (1:1.31.0-3ubuntu1) jammy; urgency=medium - * Non maintainer upload by the Reproducible Builds team. - * No source change upload to rebuild on buildd with .buildinfo files. + * debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch: + Add appropriate build tags for s390x. This fixes an autopkgtest regression + on this architecture (LP: #1965134). - -- Holger Levsen Sat, 09 Jan 2021 12:28:54 +0100 + -- Nick Rosbrook Wed, 16 Mar 2022 09:54:45 -0400 + +golang-github-valyala-fasthttp (1:1.31.0-3) unstable; urgency=medium + + * Team upload. + * Replace timeout increasing patch with ignoring tests on arm CPU arches. + + -- Guillem Jover Mon, 29 Nov 2021 12:48:27 +0100 + +golang-github-valyala-fasthttp (1:1.31.0-2) unstable; urgency=medium + + * Team upload. + * Increase the timeouts in the test suite. + * Lower case first word on package synopsis. + * Mark -dev package as Multi-Arch: foreign. + + -- Guillem Jover Fri, 26 Nov 2021 19:25:49 +0100 + +golang-github-valyala-fasthttp (1:1.31.0-1) unstable; urgency=medium + + * Team upload. + + [ Debian Janitor ] + * Set upstream metadata fields: Bug-Database, Bug-Submit, Repository, + Repository-Browse. + + [ Alois Micard ] + * Fix various Lintian warnings. + * Remove dummy file. + + [ Guillem Jover ] + * New upstream release (bump epoch due to upstream versioning reset). + * Update .gitignore files. + * Update debian/watch file to format 4. + * Update gitlab-ci.yml from its upstream source. + * Update gbp packaging workflow. + * Switch to debhelper-compat level 13. + * Switch to dh-sequence-golang from dh-golang and --with=golang. + * Wrap and sort -sat. + * Remove redundant copyright and license name from License field. + * Switch Section from devel to golang. + * Switch to Standards-Version 4.6.0 (no changes needed). + * Use execute_after_ instead of overide_ and explicit dh_ call. + * Specify --builddirectory=_build. + * Update dependencies for 1.31.0 release. + * Update Go Team Maintainer address. + * Include README.md and stop ignoring tests results. + * Remove unused ${shlibs:Depends} from Depends field. + * Update copyright claims and years. 
+ + -- Guillem Jover Mon, 22 Nov 2021 16:46:12 +0100 golang-github-valyala-fasthttp (20160617-2) unstable; urgency=medium diff -Nru golang-github-valyala-fasthttp-20160617/debian/compat golang-github-valyala-fasthttp-1.31.0/debian/compat --- golang-github-valyala-fasthttp-20160617/debian/compat 2017-12-20 03:12:38.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/compat 1970-01-01 00:00:00.000000000 +0000 @@ -1 +0,0 @@ -9 diff -Nru golang-github-valyala-fasthttp-20160617/debian/control golang-github-valyala-fasthttp-1.31.0/debian/control --- golang-github-valyala-fasthttp-20160617/debian/control 2018-02-16 23:18:37.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/control 2022-03-16 13:54:45.000000000 +0000 @@ -1,13 +1,23 @@ Source: golang-github-valyala-fasthttp -Section: devel +Section: golang Priority: optional -Maintainer: Debian Go Packaging Team -Uploaders: Nobuhiro Iwamatsu -Build-Depends: debhelper (>= 9), - dh-golang, - golang-any, - golang-github-klauspost-compress-dev -Standards-Version: 4.1.1 +Maintainer: Ubuntu Developers +XSBC-Original-Maintainer: Debian Go Packaging Team +Uploaders: + Nobuhiro Iwamatsu , +Build-Depends: + debhelper-compat (= 13), + dh-sequence-golang, + golang-any, + golang-github-andybalholm-brotli-dev, + golang-github-klauspost-compress-dev, + golang-github-valyala-bytebufferpool-dev, + golang-github-valyala-tcplisten-dev, + golang-golang-x-crypto-dev, + golang-golang-x-net-dev, + golang-golang-x-sys-dev, +Standards-Version: 4.6.0 +Rules-Requires-Root: no Homepage: https://github.com/valyala/fasthttp Vcs-Browser: https://salsa.debian.org/go-team/packages/golang-github-valyala-fasthttp Vcs-Git: https://salsa.debian.org/go-team/packages/golang-github-valyala-fasthttp.git @@ -16,10 +26,17 @@ Package: golang-github-valyala-fasthttp-dev Architecture: all -Depends: ${shlibs:Depends}, - ${misc:Depends}, - golang-github-klauspost-compress-dev -Description: Fast HTTP library for Go +Multi-Arch: foreign +Depends: + golang-github-andybalholm-brotli-dev, + golang-github-klauspost-compress-dev, + golang-github-valyala-bytebufferpool-dev, + golang-github-valyala-tcplisten-dev, + golang-golang-x-crypto-dev, + golang-golang-x-net-dev, + golang-golang-x-sys-dev, + ${misc:Depends}, +Description: fast HTTP library for Go The fasthttp library provides fast HTTP server and client API. . This tuned for high performance, and zero memory allocations in diff -Nru golang-github-valyala-fasthttp-20160617/debian/copyright golang-github-valyala-fasthttp-1.31.0/debian/copyright --- golang-github-valyala-fasthttp-20160617/debian/copyright 2018-02-27 06:52:25.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/copyright 2021-11-22 15:57:58.000000000 +0000 @@ -1,9 +1,9 @@ Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ -Upstream-Name: fasthttp Source: https://github.com/valyala/fasthttp +Upstream-Name: fasthttp Files: * -Copyright: 2015-2016 Aliaksandr Valialkin, VertaMedia +Copyright: 2015-2021 Aliaksandr Valialkin, VertaMedia License: Expat Files: reuseport/* @@ -12,14 +12,11 @@ Files: debian/* Copyright: 2017 Nobuhiro Iwamatsu + 2021 Sipwise GmbH, Austria License: Expat Comment: Debian packaging is licensed under the same terms as upstream License: Expat - The MIT License (MIT) - . - Copyright (c) 2015-2016 Aliaksandr Valialkin, VertaMedia - . 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights diff -Nru golang-github-valyala-fasthttp-20160617/debian/gbp.conf golang-github-valyala-fasthttp-1.31.0/debian/gbp.conf --- golang-github-valyala-fasthttp-20160617/debian/gbp.conf 2017-12-15 03:25:23.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/gbp.conf 2021-11-22 15:45:23.000000000 +0000 @@ -1,2 +1,3 @@ [DEFAULT] -pristine-tar = True +debian-branch = debian/sid +dist = DEP14 diff -Nru golang-github-valyala-fasthttp-20160617/debian/gitlab-ci.yml golang-github-valyala-fasthttp-1.31.0/debian/gitlab-ci.yml --- golang-github-valyala-fasthttp-20160617/debian/gitlab-ci.yml 2018-02-27 23:41:07.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/gitlab-ci.yml 2021-11-22 15:45:23.000000000 +0000 @@ -1,9 +1,7 @@ - # auto-generated, DO NOT MODIFY. # The authoritative copy of this file lives at: -# https://salsa.debian.org/go-team/ci/blob/master/cmd/ci/gitlabciyml.go +# https://salsa.debian.org/go-team/ci/blob/master/config/gitlabciyml.go -# TODO: publish under debian-go-team/ci image: stapelberg/ci2 test_the_archive: diff -Nru golang-github-valyala-fasthttp-20160617/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch golang-github-valyala-fasthttp-1.31.0/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch --- golang-github-valyala-fasthttp-20160617/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/patches/0001-bytesconv-add-appropriate-build-tags-for-s390x.patch 2022-03-16 13:54:45.000000000 +0000 @@ -0,0 +1,73 @@ +Description: Add appropriate build tags for s390x + The bytesconv 32-bit tests fail on s390x, because it is a 64-bit + architecture. Add the appropriate build flags so that 32-bit tests do + not run on this architecture. +Author: Nick Rosbrook +Forwarded: https://github.com/valyala/fasthttp/pull/1250 +Last-Update: 2022-03-16 +--- +From d6c6e4a7cc9c17158dc2c93090e5b7d26ca42e15 Mon Sep 17 00:00:00 2001 +From: Nick Rosbrook +Date: Wed, 16 Mar 2022 09:41:03 -0400 +Subject: [PATCH] bytesconv: add appropriate build tags for s390x + +The bytesconv 32-bit tests fail on s390x, because it is a 64-bit +architecture. Add the appropriate build flags so that 32-bit tests do +not run on this architecture. 
+--- + bytesconv_32.go | 4 ++-- + bytesconv_32_test.go | 4 ++-- + bytesconv_64.go | 4 ++-- + bytesconv_64_test.go | 4 ++-- + 4 files changed, 8 insertions(+), 8 deletions(-) +diff --git a/bytesconv_32.go b/bytesconv_32.go +index 6a6fec2..b574883 100644 +--- a/bytesconv_32.go ++++ b/bytesconv_32.go +@@ -1,5 +1,5 @@ +-//go:build !amd64 && !arm64 && !ppc64 && !ppc64le +-// +build !amd64,!arm64,!ppc64,!ppc64le ++//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x ++// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x + + package fasthttp + +diff --git a/bytesconv_32_test.go b/bytesconv_32_test.go +index cec5aa9..3f5d5de 100644 +--- a/bytesconv_32_test.go ++++ b/bytesconv_32_test.go +@@ -1,5 +1,5 @@ +-//go:build !amd64 && !arm64 && !ppc64 && !ppc64le +-// +build !amd64,!arm64,!ppc64,!ppc64le ++//go:build !amd64 && !arm64 && !ppc64 && !ppc64le && !s390x ++// +build !amd64,!arm64,!ppc64,!ppc64le,!s390x + + package fasthttp + +diff --git a/bytesconv_64.go b/bytesconv_64.go +index 1300d5a..94d0ec6 100644 +--- a/bytesconv_64.go ++++ b/bytesconv_64.go +@@ -1,5 +1,5 @@ +-//go:build amd64 || arm64 || ppc64 || ppc64le +-// +build amd64 arm64 ppc64 ppc64le ++//go:build amd64 || arm64 || ppc64 || ppc64le || s390x ++// +build amd64 arm64 ppc64 ppc64le s390x + + package fasthttp + +diff --git a/bytesconv_64_test.go b/bytesconv_64_test.go +index 5351591..0689809 100644 +--- a/bytesconv_64_test.go ++++ b/bytesconv_64_test.go +@@ -1,5 +1,5 @@ +-//go:build amd64 || arm64 || ppc64 || ppc64le +-// +build amd64 arm64 ppc64 ppc64le ++//go:build amd64 || arm64 || ppc64 || ppc64le || s390x ++// +build amd64 arm64 ppc64 ppc64le s390x + + package fasthttp + +-- +2.32.0 + diff -Nru golang-github-valyala-fasthttp-20160617/debian/patches/series golang-github-valyala-fasthttp-1.31.0/debian/patches/series --- golang-github-valyala-fasthttp-20160617/debian/patches/series 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/patches/series 2022-03-16 13:54:45.000000000 +0000 @@ -0,0 +1 @@ +0001-bytesconv-add-appropriate-build-tags-for-s390x.patch diff -Nru golang-github-valyala-fasthttp-20160617/debian/rules golang-github-valyala-fasthttp-1.31.0/debian/rules --- golang-github-valyala-fasthttp-20160617/debian/rules 2018-02-27 06:56:05.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/rules 2021-11-29 11:48:27.000000000 +0000 @@ -1,12 +1,20 @@ #!/usr/bin/make -f +include /usr/share/dpkg/architecture.mk + +export DH_GOLANG_INSTALL_EXTRA := \ + README.md \ + # EOL + %: - dh $@ --buildsystem=golang --with=golang + dh $@ --buildsystem=golang --builddirectory=_build +execute_after_dh_auto_install: + # Remove binary. + rm -rf debian/golang-github-valyala-fasthttp-dev/usr/bin + +# Tests on arm-based architectures timeout. 
+ifeq ($(DEB_HOST_ARCH_CPU),arm) override_dh_auto_test: -dh_auto_test - -override_dh_auto_install: - dh_auto_install - # remove binary - rm -rf ./debian/golang-github-valyala-fasthttp-dev/usr/bin +endif diff -Nru golang-github-valyala-fasthttp-20160617/debian/upstream/metadata golang-github-valyala-fasthttp-1.31.0/debian/upstream/metadata --- golang-github-valyala-fasthttp-20160617/debian/upstream/metadata 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/upstream/metadata 2021-11-18 22:21:21.000000000 +0000 @@ -0,0 +1,5 @@ +--- +Bug-Database: https://github.com/valyala/fasthttp/issues +Bug-Submit: https://github.com/valyala/fasthttp/issues/new +Repository: https://github.com/valyala/fasthttp.git +Repository-Browse: https://github.com/valyala/fasthttp diff -Nru golang-github-valyala-fasthttp-20160617/debian/watch golang-github-valyala-fasthttp-1.31.0/debian/watch --- golang-github-valyala-fasthttp-20160617/debian/watch 2017-12-15 03:25:23.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/debian/watch 2021-11-22 15:45:23.000000000 +0000 @@ -1,4 +1,4 @@ -version=3 -opts=filenamemangle=s/.+\/v?(\d\S*)\.tar\.gz/golang-github-valyala-fasthttp-\$1\.tar\.gz/,\ -uversionmangle=s/(\d)[_\.\-\+]?(RC|rc|pre|dev|beta|alpha)[.]?(\d*)$/\$1~\$2\$3/ \ +version=4 +opts="filenamemangle=s%(?:.*?)?v?(\d[\d.]*)\.tar\.gz%golang-github-valyala-fasthttp-$1.tar.gz%,\ + uversionmangle=s/(\d)[_\.\-\+]?(RC|rc|pre|dev|beta|alpha)[.]?(\d*)$/\$1~\$2\$3/" \ https://github.com/valyala/fasthttp/tags .*/v?(\d\S*)\.tar\.gz diff -Nru golang-github-valyala-fasthttp-20160617/doc.go golang-github-valyala-fasthttp-1.31.0/doc.go --- golang-github-valyala-fasthttp-20160617/doc.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/doc.go 2021-10-09 18:39:05.000000000 +0000 @@ -7,9 +7,6 @@ concurrent keep-alive connections on modern hardware. * Optimized for low memory usage. * Easy 'Connection: Upgrade' support via RequestCtx.Hijack. - * Server supports requests' pipelining. Multiple requests may be read from - a single network packet and multiple responses may be sent in a single - network packet. This may be useful for highly loaded REST services. * Server provides the following anti-DoS limits: * The number of concurrent connections. diff -Nru golang-github-valyala-fasthttp-20160617/examples/fileserver/fileserver.go golang-github-valyala-fasthttp-1.31.0/examples/fileserver/fileserver.go --- golang-github-valyala-fasthttp-20160617/examples/fileserver/fileserver.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/examples/fileserver/fileserver.go 2021-10-09 18:39:05.000000000 +0000 @@ -17,11 +17,11 @@ addr = flag.String("addr", "localhost:8080", "TCP address to listen to") addrTLS = flag.String("addrTLS", "", "TCP address to listen to TLS (aka SSL or HTTPS) requests. 
Leave empty for disabling TLS") byteRange = flag.Bool("byteRange", false, "Enables byte range requests if set to true") - certFile = flag.String("certFile", "./ssl-cert-snakeoil.pem", "Path to TLS certificate file") + certFile = flag.String("certFile", "./ssl-cert.pem", "Path to TLS certificate file") compress = flag.Bool("compress", false, "Enables transparent response compression if set to true") dir = flag.String("dir", "/usr/share/nginx/html", "Directory to serve static files from") generateIndexPages = flag.Bool("generateIndexPages", true, "Whether to generate directory index pages") - keyFile = flag.String("keyFile", "./ssl-cert-snakeoil.key", "Path to TLS key file") + keyFile = flag.String("keyFile", "./ssl-cert.key", "Path to TLS key file") vhost = flag.Bool("vhost", false, "Enables virtual hosting by prepending the requested path with the requested hostname") ) diff -Nru golang-github-valyala-fasthttp-20160617/examples/letsencrypt/letsencryptserver.go golang-github-valyala-fasthttp-1.31.0/examples/letsencrypt/letsencryptserver.go --- golang-github-valyala-fasthttp-20160617/examples/letsencrypt/letsencryptserver.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/examples/letsencrypt/letsencryptserver.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,41 @@ +package main + +import ( + "crypto/tls" + "net" + + "github.com/valyala/fasthttp" + "golang.org/x/crypto/acme" + "golang.org/x/crypto/acme/autocert" +) + +func requestHandler(ctx *fasthttp.RequestCtx) { + ctx.SetBodyString("hello from https!") +} + +func main() { + m := &autocert.Manager{ + Prompt: autocert.AcceptTOS, + HostPolicy: autocert.HostWhitelist("example.com"), // Replace with your domain. + Cache: autocert.DirCache("./certs"), + } + + cfg := &tls.Config{ + GetCertificate: m.GetCertificate, + NextProtos: []string{ + "http/1.1", acme.ALPNProto, + }, + } + + // Let's Encrypt tls-alpn-01 only works on port 443. 
+ ln, err := net.Listen("tcp4", "0.0.0.0:443") /* #nosec G102 */ + if err != nil { + panic(err) + } + + lnTls := tls.NewListener(ln, cfg) + + if err := fasthttp.Serve(lnTls, requestHandler); err != nil { + panic(err) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/examples/multidomain/Makefile golang-github-valyala-fasthttp-1.31.0/examples/multidomain/Makefile --- golang-github-valyala-fasthttp-20160617/examples/multidomain/Makefile 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/examples/multidomain/Makefile 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,6 @@ +writer: clean + go get -u github.com/valyala/fasthttp + go build + +clean: + rm -f multidomain diff -Nru golang-github-valyala-fasthttp-20160617/examples/multidomain/multidomain.go golang-github-valyala-fasthttp-1.31.0/examples/multidomain/multidomain.go --- golang-github-valyala-fasthttp-20160617/examples/multidomain/multidomain.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/examples/multidomain/multidomain.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,63 @@ +package main + +import ( + "fmt" + + "github.com/valyala/fasthttp" +) + +var domains = make(map[string]fasthttp.RequestHandler) + +func main() { + server := &fasthttp.Server{ + // You can check the access using openssl command: + // $ openssl s_client -connect localhost:8080 << EOF + // > GET / + // > Host: localhost + // > EOF + // + // $ openssl s_client -connect localhost:8080 << EOF + // > GET / + // > Host: 127.0.0.1:8080 + // > EOF + // + Handler: func(ctx *fasthttp.RequestCtx) { + h, ok := domains[string(ctx.Host())] + if !ok { + ctx.NotFound() + return + } + h(ctx) + }, + } + + // preparing first host + cert, priv, err := fasthttp.GenerateTestCertificate("localhost:8080") + if err != nil { + panic(err) + } + domains["localhost:8080"] = func(ctx *fasthttp.RequestCtx) { + ctx.Write([]byte("You are accessing to localhost:8080\n")) + } + + err = server.AppendCertEmbed(cert, priv) + if err != nil { + panic(err) + } + + // preparing second host + cert, priv, err = fasthttp.GenerateTestCertificate("127.0.0.1") + if err != nil { + panic(err) + } + domains["127.0.0.1:8080"] = func(ctx *fasthttp.RequestCtx) { + ctx.Write([]byte("You are accessing to 127.0.0.1:8080\n")) + } + + err = server.AppendCertEmbed(cert, priv) + if err != nil { + panic(err) + } + + fmt.Println(server.ListenAndServeTLS(":8080", "", "")) +} diff -Nru golang-github-valyala-fasthttp-20160617/examples/multidomain/README.md golang-github-valyala-fasthttp-1.31.0/examples/multidomain/README.md --- golang-github-valyala-fasthttp-20160617/examples/multidomain/README.md 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/examples/multidomain/README.md 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,15 @@ +# Multidomain using SSL certs example + +* Prints two messages depending on visited host. 
+ +# How to build + +``` +make +``` + +# How to run + +``` +./multidomain +``` diff -Nru golang-github-valyala-fasthttp-20160617/expvarhandler/expvar.go golang-github-valyala-fasthttp-1.31.0/expvarhandler/expvar.go --- golang-github-valyala-fasthttp-20160617/expvarhandler/expvar.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/expvarhandler/expvar.go 2021-10-09 18:39:05.000000000 +0000 @@ -13,6 +13,8 @@ var ( expvarHandlerCalls = expvar.NewInt("expvarHandlerCalls") expvarRegexpErrors = expvar.NewInt("expvarRegexpErrors") + + defaultRE = regexp.MustCompile(".") ) // ExpvarHandler dumps json representation of expvars to http response. @@ -36,10 +38,10 @@ fmt.Fprintf(ctx, "{\n") first := true expvar.Do(func(kv expvar.KeyValue) { - if !first { - fmt.Fprintf(ctx, ",\n") - } if r.MatchString(kv.Key) { + if !first { + fmt.Fprintf(ctx, ",\n") + } first = false fmt.Fprintf(ctx, "\t%q: %s", kv.Key, kv.Value) } @@ -52,7 +54,7 @@ func getExpvarRegexp(ctx *fasthttp.RequestCtx) (*regexp.Regexp, error) { r := string(ctx.QueryArgs().Peek("r")) if len(r) == 0 { - r = "." + return defaultRE, nil } rr, err := regexp.Compile(r) if err != nil { diff -Nru golang-github-valyala-fasthttp-20160617/expvarhandler/expvar_test.go golang-github-valyala-fasthttp-1.31.0/expvarhandler/expvar_test.go --- golang-github-valyala-fasthttp-20160617/expvarhandler/expvar_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/expvarhandler/expvar_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -10,6 +10,8 @@ ) func TestExpvarHandlerBasic(t *testing.T) { + t.Parallel() + expvar.Publish("customVar", expvar.Func(func() interface{} { return "foobar" })) diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor.go golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/adaptor.go --- golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/adaptor.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,9 +3,7 @@ package fasthttpadaptor import ( - "io" "net/http" - "net/url" "github.com/valyala/fasthttp" ) @@ -49,60 +47,38 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { return func(ctx *fasthttp.RequestCtx) { var r http.Request - - body := ctx.PostBody() - r.Method = string(ctx.Method()) - r.Proto = "HTTP/1.1" - r.ProtoMajor = 1 - r.ProtoMinor = 1 - r.RequestURI = string(ctx.RequestURI()) - r.ContentLength = int64(len(body)) - r.Host = string(ctx.Host()) - r.RemoteAddr = ctx.RemoteAddr().String() - - hdr := make(http.Header) - ctx.Request.Header.VisitAll(func(k, v []byte) { - hdr.Set(string(k), string(v)) - }) - r.Header = hdr - r.Body = &netHTTPBody{body} - rURL, err := url.ParseRequestURI(r.RequestURI) - if err != nil { + if err := ConvertRequest(ctx, &r, true); err != nil { ctx.Logger().Printf("cannot parse requestURI %q: %s", r.RequestURI, err) ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - r.URL = rURL var w netHTTPResponseWriter - h.ServeHTTP(&w, &r) + h.ServeHTTP(&w, r.WithContext(ctx)) ctx.SetStatusCode(w.StatusCode()) + haveContentType := false for k, vv := range w.Header() { + if k == fasthttp.HeaderContentType { + haveContentType = true + } + for _, v := range vv { - ctx.Response.Header.Set(k, v) + ctx.Response.Header.Add(k, v) } } - ctx.Write(w.body) - } -} - -type netHTTPBody struct { - b []byte -} - -func (r *netHTTPBody) Read(p []byte) (int, error) { - if len(r.b) == 0 { - return 0, io.EOF + if 
!haveContentType { + // From net/http.ResponseWriter.Write: + // If the Header does not contain a Content-Type line, Write adds a Content-Type set + // to the result of passing the initial 512 bytes of written data to DetectContentType. + l := 512 + if len(w.body) < 512 { + l = len(w.body) + } + ctx.Response.Header.Set(fasthttp.HeaderContentType, http.DetectContentType(w.body[:l])) + } + ctx.Write(w.body) //nolint:errcheck } - n := copy(p, r.b) - r.b = r.b[n:] - return n, nil -} - -func (r *netHTTPBody) Close() error { - r.b = r.b[:0] - return nil } type netHTTPResponseWriter struct { diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor_test.go golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/adaptor_test.go --- golang-github-valyala-fasthttp-20160617/fasthttpadaptor/adaptor_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/adaptor_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -13,7 +13,9 @@ ) func TestNewFastHTTPHandler(t *testing.T) { - expectedMethod := "POST" + t.Parallel() + + expectedMethod := fasthttp.MethodPost expectedProto := "HTTP/1.1" expectedProtoMajor := 1 expectedProtoMinor := 1 @@ -31,6 +33,8 @@ if err != nil { t.Fatalf("unexpected error: %s", err) } + expectedContextKey := "contextKey" + expectedContextValue := "contextValue" callsCount := 0 nethttpH := func(w http.ResponseWriter, r *http.Request) { @@ -53,6 +57,9 @@ if r.ContentLength != int64(expectedContentLength) { t.Fatalf("unexpected contentLength %d. Expecting %d", r.ContentLength, expectedContentLength) } + if len(r.TransferEncoding) != 0 { + t.Fatalf("unexpected transferEncoding %q. Expecting []", r.TransferEncoding) + } if r.Host != expectedHost { t.Fatalf("unexpected host %q. Expecting %q", r.Host, expectedHost) } @@ -70,6 +77,9 @@ if !reflect.DeepEqual(r.URL, expectedURL) { t.Fatalf("unexpected URL: %#v. Expecting %#v", r.URL, expectedURL) } + if r.Context().Value(expectedContextKey) != expectedContextValue { + t.Fatalf("unexpected context value for key %q. Expecting %q", expectedContextKey, expectedContextValue) + } for k, expectedV := range expectedHeader { v := r.Header.Get(k) @@ -84,6 +94,7 @@ fmt.Fprintf(w, "request body is %q", body) } fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH)) + fasthttpH = setContextValueMiddleware(fasthttpH, expectedContextKey, expectedContextValue) var ctx fasthttp.RequestCtx var req fasthttp.Request @@ -91,7 +102,7 @@ req.Header.SetMethod(expectedMethod) req.SetRequestURI(expectedRequestURI) req.Header.SetHost(expectedHost) - req.BodyWriter().Write([]byte(expectedBody)) + req.BodyWriter().Write([]byte(expectedBody)) // nolint:errcheck for k, v := range expectedHeader { req.Header.Set(k, v) } @@ -123,3 +134,39 @@ t.Fatalf("unexpected response body %q. 
Expecting %q", resp.Body(), expectedResponseBody) } } + +func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value interface{}) fasthttp.RequestHandler { + return func(ctx *fasthttp.RequestCtx) { + ctx.SetUserValue(key, value) + next(ctx) + } +} + +func TestContentType(t *testing.T) { + t.Parallel() + + nethttpH := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("")) //nolint:errcheck + } + fasthttpH := NewFastHTTPHandler(http.HandlerFunc(nethttpH)) + + var ctx fasthttp.RequestCtx + var req fasthttp.Request + + req.SetRequestURI("http://example.com") + + remoteAddr, err := net.ResolveTCPAddr("tcp", "1.2.3.4:80") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx.Init(&req, remoteAddr, nil) + + fasthttpH(&ctx) + + resp := &ctx.Response + got := string(resp.Header.Peek("Content-Type")) + expected := "text/html; charset=utf-8" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpadaptor/request.go golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/request.go --- golang-github-valyala-fasthttp-20160617/fasthttpadaptor/request.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpadaptor/request.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,59 @@ +package fasthttpadaptor + +import ( + "bytes" + "io/ioutil" + "net/http" + "net/url" + + "github.com/valyala/fasthttp" +) + +// ConvertRequest convert a fasthttp.Request to an http.Request +// forServer should be set to true when the http.Request is going to passed to a http.Handler. +func ConvertRequest(ctx *fasthttp.RequestCtx, r *http.Request, forServer bool) error { + body := ctx.PostBody() + strRequestURI := string(ctx.RequestURI()) + + rURL, err := url.ParseRequestURI(strRequestURI) + if err != nil { + return err + } + + r.Method = string(ctx.Method()) + r.Proto = "HTTP/1.1" + r.ProtoMajor = 1 + r.ProtoMinor = 1 + r.ContentLength = int64(len(body)) + r.RemoteAddr = ctx.RemoteAddr().String() + r.Host = string(ctx.Host()) + r.TLS = ctx.TLSConnectionState() + r.Body = ioutil.NopCloser(bytes.NewReader(body)) + r.URL = rURL + + if forServer { + r.RequestURI = strRequestURI + } + + if r.Header == nil { + r.Header = make(http.Header) + } else if len(r.Header) > 0 { + for k := range r.Header { + delete(r.Header, k) + } + } + + ctx.Request.Header.VisitAll(func(k, v []byte) { + sk := string(k) + sv := string(v) + + switch sk { + case "Transfer-Encoding": + r.TransferEncoding = append(r.TransferEncoding, sv) + default: + r.Header.Set(sk, sv) + } + }) + + return nil +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpproxy/http.go golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/http.go --- golang-github-valyala-fasthttp-20160617/fasthttpproxy/http.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/http.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,77 @@ +package fasthttpproxy + +import ( + "bufio" + "encoding/base64" + "fmt" + "net" + "strings" + "time" + + "github.com/valyala/fasthttp" +) + +// FasthttpHTTPDialer returns a fasthttp.DialFunc that dials using +// the provided HTTP proxy. 
+// +// Example usage: +// c := &fasthttp.Client{ +// Dial: fasthttpproxy.FasthttpHTTPDialer("username:password@localhost:9050"), +// } +func FasthttpHTTPDialer(proxy string) fasthttp.DialFunc { + return FasthttpHTTPDialerTimeout(proxy, 0) +} + +// FasthttpHTTPDialerTimeout returns a fasthttp.DialFunc that dials using +// the provided HTTP proxy using the given timeout. +// +// Example usage: +// c := &fasthttp.Client{ +// Dial: fasthttpproxy.FasthttpHTTPDialerTimeout("username:password@localhost:9050", time.Second * 2), +// } +func FasthttpHTTPDialerTimeout(proxy string, timeout time.Duration) fasthttp.DialFunc { + var auth string + if strings.Contains(proxy, "@") { + split := strings.Split(proxy, "@") + auth = base64.StdEncoding.EncodeToString([]byte(split[0])) + proxy = split[1] + } + + return func(addr string) (net.Conn, error) { + var conn net.Conn + var err error + if timeout == 0 { + conn, err = fasthttp.Dial(proxy) + } else { + conn, err = fasthttp.DialTimeout(proxy, timeout) + } + if err != nil { + return nil, err + } + + req := "CONNECT " + addr + " HTTP/1.1\r\n" + if auth != "" { + req += "Proxy-Authorization: Basic " + auth + "\r\n" + } + req += "\r\n" + + if _, err := conn.Write([]byte(req)); err != nil { + return nil, err + } + + res := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(res) + + res.SkipBody = true + + if err := res.Read(bufio.NewReader(conn)); err != nil { + conn.Close() + return nil, err + } + if res.Header.StatusCode() != 200 { + conn.Close() + return nil, fmt.Errorf("could not connect to proxy: %s status code: %d", proxy, res.Header.StatusCode()) + } + return conn, nil + } +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpproxy/proxy_env.go golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/proxy_env.go --- golang-github-valyala-fasthttp-20160617/fasthttpproxy/proxy_env.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/proxy_env.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,125 @@ +package fasthttpproxy + +import ( + "bufio" + "encoding/base64" + "fmt" + "net" + "net/url" + "sync/atomic" + "time" + + "golang.org/x/net/http/httpproxy" + + "github.com/valyala/fasthttp" +) + +const ( + httpsScheme = "https" + httpScheme = "http" + tlsPort = "443" +) + +// FasthttpProxyHTTPDialer returns a fasthttp.DialFunc that dials using +// the the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy. +// +// Example usage: +// c := &fasthttp.Client{ +// Dial: FasthttpProxyHTTPDialer(), +// } +func FasthttpProxyHTTPDialer() fasthttp.DialFunc { + return FasthttpProxyHTTPDialerTimeout(0) +} + +// FasthttpProxyHTTPDialer returns a fasthttp.DialFunc that dials using +// the env(HTTP_PROXY, HTTPS_PROXY and NO_PROXY) configured HTTP proxy using the given timeout. +// +// Example usage: +// c := &fasthttp.Client{ +// Dial: FasthttpProxyHTTPDialerTimeout(time.Second * 2), +// } +func FasthttpProxyHTTPDialerTimeout(timeout time.Duration) fasthttp.DialFunc { + proxier := httpproxy.FromEnvironment().ProxyFunc() + + // encoded auth barrier for http and https proxy. 
+ authHTTPStorage := &atomic.Value{} + authHTTPSStorage := &atomic.Value{} + + return func(addr string) (net.Conn, error) { + + port, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("unexpected addr format: %v", err) + } + + reqURL := &url.URL{Host: addr, Scheme: httpScheme} + if port == tlsPort { + reqURL.Scheme = httpsScheme + } + proxyURL, err := proxier(reqURL) + if err != nil { + return nil, err + } + + if proxyURL == nil { + if timeout == 0 { + return fasthttp.Dial(addr) + } + return fasthttp.DialTimeout(addr, timeout) + } + + var conn net.Conn + if timeout == 0 { + conn, err = fasthttp.Dial(proxyURL.Host) + } else { + conn, err = fasthttp.DialTimeout(proxyURL.Host, timeout) + } + if err != nil { + return nil, err + } + + req := "CONNECT " + addr + " HTTP/1.1\r\n" + + if proxyURL.User != nil { + authBarrierStorage := authHTTPStorage + if port == tlsPort { + authBarrierStorage = authHTTPSStorage + } + + auth := authBarrierStorage.Load() + if auth == nil { + authBarrier := base64.StdEncoding.EncodeToString([]byte(proxyURL.User.String())) + auth := &authBarrier + authBarrierStorage.Store(auth) + } + + req += "Proxy-Authorization: Basic " + *auth.(*string) + "\r\n" + } + req += "\r\n" + + if _, err := conn.Write([]byte(req)); err != nil { + return nil, err + } + + res := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(res) + + res.SkipBody = true + + if err := res.Read(bufio.NewReader(conn)); err != nil { + if connErr := conn.Close(); connErr != nil { + return nil, fmt.Errorf("conn close err %v followed by read conn err %v", connErr, err) + } + return nil, err + } + if res.Header.StatusCode() != 200 { + if connErr := conn.Close(); connErr != nil { + return nil, fmt.Errorf( + "conn close err %v followed by connect to proxy: code: %d body %s", + connErr, res.StatusCode(), string(res.Body())) + } + return nil, fmt.Errorf("could not connect to proxy: code: %d body %s", res.StatusCode(), string(res.Body())) + } + return conn, nil + } +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttpproxy/socks5.go golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/socks5.go --- golang-github-valyala-fasthttp-20160617/fasthttpproxy/socks5.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttpproxy/socks5.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,38 @@ +package fasthttpproxy + +import ( + "net" + "net/url" + + "github.com/valyala/fasthttp" + "golang.org/x/net/proxy" +) + +// FasthttpSocksDialer returns a fasthttp.DialFunc that dials using +// the provided SOCKS5 proxy. +// +// Example usage: +// c := &fasthttp.Client{ +// Dial: fasthttpproxy.FasthttpSocksDialer("socks5://localhost:9050"), +// } +func FasthttpSocksDialer(proxyAddr string) fasthttp.DialFunc { + var ( + u *url.URL + err error + dialer proxy.Dialer + ) + if u, err = url.Parse(proxyAddr); err == nil { + dialer, err = proxy.FromURL(u, proxy.Direct) + } + // It would be nice if we could return the error here. But we can't + // change our API so just keep returning it in the returned Dial function. + // Besides the implementation of proxy.SOCKS5() at the time of writing this + // will always return nil as error. 
+ + return func(addr string) (net.Conn, error) { + if err != nil { + return nil, err + } + return dialer.Dial("tcp", addr) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/ecdsa.key golang-github-valyala-fasthttp-1.31.0/fasthttputil/ecdsa.key --- golang-github-valyala-fasthttp-20160617/fasthttputil/ecdsa.key 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/ecdsa.key 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIBpQbZ6a5jL1Yh4wdP6yZk4MKjYWArD/QOLENFw8vbELoAoGCCqGSM49 +AwEHoUQDQgAEKQCZWgE2IBhb47ot8MIs1D4KSisHYlZ41IWyeutpjb0fjwwIhimh +pl1Qld1/d2j3Z3vVyfa5yD+ncV7qCFZuSg== +-----END EC PRIVATE KEY----- diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/ecdsa.pem golang-github-valyala-fasthttp-1.31.0/fasthttputil/ecdsa.pem --- golang-github-valyala-fasthttp-20160617/fasthttputil/ecdsa.pem 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/ecdsa.pem 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,10 @@ +-----BEGIN CERTIFICATE----- +MIIBbTCCAROgAwIBAgIQPo718S+K+G7hc1SgTEU4QDAKBggqhkjOPQQDAjASMRAw +DgYDVQQKEwdBY21lIENvMB4XDTE3MDQyMDIxMDExNFoXDTE4MDQyMDIxMDExNFow +EjEQMA4GA1UEChMHQWNtZSBDbzBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABCkA +mVoBNiAYW+O6LfDCLNQ+CkorB2JWeNSFsnrraY29H48MCIYpoaZdUJXdf3do92d7 +1cn2ucg/p3Fe6ghWbkqjSzBJMA4GA1UdDwEB/wQEAwIFoDATBgNVHSUEDDAKBggr +BgEFBQcDATAMBgNVHRMBAf8EAjAAMBQGA1UdEQQNMAuCCWxvY2FsaG9zdDAKBggq +hkjOPQQDAgNIADBFAiEAoLAIQkvSuIcHUqyWroA6yWYw2fznlRH/uO9/hMCxUCEC +IClRYb/5O9eD/Eq/ozPnwNpsQHOeYefEhadJ/P82y0lG +-----END CERTIFICATE----- diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener.go golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener.go --- golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,25 +1,33 @@ package fasthttputil import ( - "fmt" + "errors" "net" "sync" ) +// ErrInmemoryListenerClosed indicates that the InmemoryListener is already closed. +var ErrInmemoryListenerClosed = errors.New("InmemoryListener is already closed: use of closed network connection") + // InmemoryListener provides in-memory dialer<->net.Listener implementation. // -// It may be used either for fast in-process client<->server communcations +// It may be used either for fast in-process client<->server communications // without network stack overhead or for client<->server tests. type InmemoryListener struct { lock sync.Mutex closed bool - conns chan net.Conn + conns chan acceptConn +} + +type acceptConn struct { + conn net.Conn + accepted chan struct{} } // NewInmemoryListener returns new in-memory dialer<->net.Listener. func NewInmemoryListener() *InmemoryListener { return &InmemoryListener{ - conns: make(chan net.Conn, 1024), + conns: make(chan acceptConn, 1024), } } @@ -31,9 +39,10 @@ func (ln *InmemoryListener) Accept() (net.Conn, error) { c, ok := <-ln.conns if !ok { - return nil, fmt.Errorf("InmemoryListener is already closed: use of closed network connection") + return nil, ErrInmemoryListenerClosed } - return c, nil + close(c.accepted) + return c.conn, nil } // Close implements net.Listener's Close. 
@@ -45,7 +54,7 @@ close(ln.conns) ln.closed = true } else { - err = fmt.Errorf("InmemoryListener is already closed") + err = ErrInmemoryListenerClosed } ln.lock.Unlock() return err @@ -59,8 +68,9 @@ } } -// Dial creates new client<->server connection, enqueues server side -// of the connection to Accept and returns client side of the connection. +// Dial creates new client<->server connection. +// Just like a real Dial it only returns once the server +// has accepted the connection. // // It is safe calling Dial from concurrently running goroutines. func (ln *InmemoryListener) Dial() (net.Conn, error) { @@ -68,17 +78,20 @@ cConn := pc.Conn1() sConn := pc.Conn2() ln.lock.Lock() + accepted := make(chan struct{}) if !ln.closed { - ln.conns <- sConn + ln.conns <- acceptConn{sConn, accepted} + // Wait until the connection has been accepted. + <-accepted } else { - sConn.Close() - cConn.Close() + sConn.Close() //nolint:errcheck + cConn.Close() //nolint:errcheck cConn = nil } ln.lock.Unlock() if cConn == nil { - return nil, fmt.Errorf("InmemoryListener is already closed") + return nil, ErrInmemoryListenerClosed } return cConn, nil } diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_test.go golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener_test.go --- golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -2,12 +2,20 @@ import ( "bytes" + "context" "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "sync" "testing" "time" ) func TestInmemoryListener(t *testing.T) { + t.Parallel() + ln := NewInmemoryListener() ch := make(chan struct{}) @@ -15,29 +23,29 @@ go func(n int) { conn, err := ln.Dial() if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } defer conn.Close() req := fmt.Sprintf("request_%d", n) nn, err := conn.Write([]byte(req)) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } if nn != len(req) { - t.Fatalf("unexpected number of bytes written: %d. Expecting %d", nn, len(req)) + t.Errorf("unexpected number of bytes written: %d. Expecting %d", nn, len(req)) } buf := make([]byte, 30) nn, err = conn.Read(buf) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } buf = buf[:nn] resp := fmt.Sprintf("response_%d", n) if nn != len(resp) { - t.Fatalf("unexpected number of bytes read: %d. Expecting %d", nn, len(resp)) + t.Errorf("unexpected number of bytes read: %d. Expecting %d", nn, len(resp)) } if string(buf) != resp { - t.Fatalf("unexpected response %q. Expecting %q", buf, resp) + t.Errorf("unexpected response %q. Expecting %q", buf, resp) } ch <- struct{}{} }(i) @@ -55,19 +63,19 @@ buf := make([]byte, 30) n, err := conn.Read(buf) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } buf = buf[:n] if !bytes.HasPrefix(buf, []byte("request_")) { - t.Fatalf("unexpected request prefix %q. Expecting %q", buf, "request_") + t.Errorf("unexpected request prefix %q. Expecting %q", buf, "request_") } resp := fmt.Sprintf("response_%s", buf[len("request_"):]) n, err = conn.Write([]byte(resp)) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } if n != len(resp) { - t.Fatalf("unexpected number of bytes written: %d. 
Expecting %d", n, len(resp)) + t.Errorf("unexpected number of bytes written: %d. Expecting %d", n, len(resp)) } } }() @@ -90,3 +98,95 @@ t.Fatalf("timeout") } } + +// echoServerHandler implements http.Handler. +type echoServerHandler struct { + t *testing.T +} + +func (s *echoServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + time.Sleep(time.Millisecond * 100) + if _, err := io.Copy(w, r.Body); err != nil { + s.t.Fatalf("unexpected error: %s", err) + } +} + +func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Client)) { + ln := NewInmemoryListener() + defer ln.Close() + + client := &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return ln.Dial() + }, + }, + Timeout: time.Second, + } + + server := &http.Server{ + Handler: &echoServerHandler{t}, + } + + go func() { + if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { + t.Errorf("unexpected error: %s", err) + } + }() + + f(t, client) + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100) + defer cancel() + server.Shutdown(ctx) //nolint:errcheck +} + +func testInmemoryListenerHTTPSingle(t *testing.T, client *http.Client, content string) { + res, err := client.Post("http://...", "text/plain", bytes.NewBufferString(content)) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + s := string(b) + if string(b) != content { + t.Fatalf("unexpected response %s, expecting %s", s, content) + } +} + +func TestInmemoryListenerHTTPSingle(t *testing.T) { + t.Parallel() + + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + testInmemoryListenerHTTPSingle(t, client, "request") + }) +} + +func TestInmemoryListenerHTTPSerial(t *testing.T) { + t.Parallel() + + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + for i := 0; i < 10; i++ { + testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i)) + } + }) +} + +func TestInmemoryListenerHTTPConcurrent(t *testing.T) { + t.Parallel() + + testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) { + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i)) + }(i) + } + wg.Wait() + }) +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_timing_test.go golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener_timing_test.go --- golang-github-valyala-fasthttp-20160617/fasthttputil/inmemory_listener_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/inmemory_listener_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,6 +1,7 @@ package fasthttputil_test import ( + "crypto/tls" "net" "testing" @@ -36,14 +37,124 @@ // for fasthttp client and server. // // It re-establishes new TLS connection per each http request. 
-func BenchmarkTLSHandshake(b *testing.B) { - benchmark(b, handshakeHandler, true) +func BenchmarkTLSHandshakeRSAWithClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: false, + } + benchmarkExt(b, handshakeHandler, bc) +} + +func BenchmarkTLSHandshakeRSAWithoutClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: true, + } + benchmarkExt(b, handshakeHandler, bc) +} + +func BenchmarkTLSHandshakeECDSAWithClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: false, + UseECDSA: true, + } + benchmarkExt(b, handshakeHandler, bc) +} + +func BenchmarkTLSHandshakeECDSAWithoutClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: true, + UseECDSA: true, + } + benchmarkExt(b, handshakeHandler, bc) +} + +func BenchmarkTLSHandshakeECDSAWithCurvesWithClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: false, + UseCurves: true, + UseECDSA: true, + } + benchmarkExt(b, handshakeHandler, bc) +} + +func BenchmarkTLSHandshakeECDSAWithCurvesWithoutClientSessionCache(b *testing.B) { + bc := &benchConfig{ + IsTLS: true, + DisableClientSessionCache: true, + UseCurves: true, + UseECDSA: true, + } + benchmarkExt(b, handshakeHandler, bc) } func benchmark(b *testing.B, h fasthttp.RequestHandler, isTLS bool) { + bc := &benchConfig{ + IsTLS: isTLS, + } + benchmarkExt(b, h, bc) +} + +type benchConfig struct { + IsTLS bool + DisableClientSessionCache bool + UseCurves bool + UseECDSA bool +} + +func benchmarkExt(b *testing.B, h fasthttp.RequestHandler, bc *benchConfig) { + var serverTLSConfig, clientTLSConfig *tls.Config + if bc.IsTLS { + certFile := "rsa.pem" + keyFile := "rsa.key" + if bc.UseECDSA { + certFile = "ecdsa.pem" + keyFile = "ecdsa.key" + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + b.Fatalf("cannot load TLS certificate from certFile=%q, keyFile=%q: %s", certFile, keyFile, err) + } + serverTLSConfig = &tls.Config{ + Certificates: []tls.Certificate{cert}, + PreferServerCipherSuites: true, + } + serverTLSConfig.CurvePreferences = []tls.CurveID{} + if bc.UseCurves { + serverTLSConfig.CurvePreferences = []tls.CurveID{ + tls.CurveP256, + } + } + clientTLSConfig = &tls.Config{ + InsecureSkipVerify: true, + } + if bc.DisableClientSessionCache { + clientTLSConfig.ClientSessionCache = fakeSessionCache{} + } + } ln := fasthttputil.NewInmemoryListener() - serverStopCh := startServer(b, ln, h, isTLS) - c := newClient(ln, isTLS) + serverStopCh := make(chan struct{}) + go func() { + serverLn := net.Listener(ln) + if serverTLSConfig != nil { + serverLn = tls.NewListener(serverLn, serverTLSConfig) + } + if err := fasthttp.Serve(serverLn, h); err != nil { + b.Errorf("unexpected error in server: %s", err) + } + close(serverStopCh) + }() + c := &fasthttp.HostClient{ + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + IsTLS: clientTLSConfig != nil, + TLSConfig: clientTLSConfig, + } + b.RunParallel(func(pb *testing.PB) { runRequests(b, pb, c) }) @@ -52,7 +163,7 @@ } func streamingHandler(ctx *fasthttp.RequestCtx) { - ctx.WriteString("foobar") + ctx.WriteString("foobar") //nolint:errcheck } func handshakeHandler(ctx *fasthttp.RequestCtx) { @@ -62,37 +173,6 @@ ctx.SetConnectionClose() } -func startServer(b *testing.B, ln *fasthttputil.InmemoryListener, h fasthttp.RequestHandler, isTLS bool) <-chan struct{} { - ch := make(chan struct{}) - go func() 
{ - var err error - if isTLS { - err = fasthttp.ServeTLS(ln, certFile, keyFile, h) - } else { - err = fasthttp.Serve(ln, h) - } - if err != nil { - b.Fatalf("unexpected error in server: %s", err) - } - close(ch) - }() - return ch -} - -const ( - certFile = "./ssl-cert-snakeoil.pem" - keyFile = "./ssl-cert-snakeoil.key" -) - -func newClient(ln *fasthttputil.InmemoryListener, isTLS bool) *fasthttp.HostClient { - return &fasthttp.HostClient{ - Dial: func(addr string) (net.Conn, error) { - return ln.Dial() - }, - IsTLS: isTLS, - } -} - func runRequests(b *testing.B, pb *testing.PB, c *fasthttp.HostClient) { var req fasthttp.Request req.SetRequestURI("http://foo.bar/baz") @@ -106,3 +186,13 @@ } } } + +type fakeSessionCache struct{} + +func (fakeSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) { + return nil, false +} + +func (fakeSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) { + // no-op +} diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns.go golang-github-valyala-fasthttp-1.31.0/fasthttputil/pipeconns.go --- golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/pipeconns.go 2021-10-09 18:39:05.000000000 +0000 @@ -8,7 +8,9 @@ "time" ) -// NewPipeConns returns new bi-directonal connection pipe. +// NewPipeConns returns new bi-directional connection pipe. +// +// PipeConns is NOT safe for concurrent use by multiple goroutines! func NewPipeConns() *PipeConns { ch1 := make(chan *byteBuffer, 4) ch2 := make(chan *byteBuffer, 4) @@ -36,6 +38,9 @@ // * It is faster. // * It buffers Write calls, so there is no need to have concurrent goroutine // calling Read in order to unblock each Write call. +// * It supports read and write deadlines. +// +// PipeConns is NOT safe for concurrent use by multiple goroutines! type PipeConns struct { c1 pipeConn c2 pipeConn @@ -79,6 +84,14 @@ rCh chan *byteBuffer wCh chan *byteBuffer pc *PipeConns + + readDeadlineTimer *time.Timer + writeDeadlineTimer *time.Timer + + readDeadlineCh <-chan time.Time + writeDeadlineCh <-chan time.Time + + readDeadlineChLock sync.Mutex } func (c *pipeConn) Write(p []byte) (int, error) { @@ -97,6 +110,9 @@ default: select { case c.wCh <- b: + case <-c.writeDeadlineCh: + c.writeDeadlineCh = closedDeadlineCh + return 0, ErrTimeout case <-c.pc.stopCh: releaseByteBuffer(b) return 0, errConnectionClosed @@ -147,10 +163,30 @@ if !mayBlock { return errWouldBlock } + c.readDeadlineChLock.Lock() + readDeadlineCh := c.readDeadlineCh + c.readDeadlineChLock.Unlock() select { case c.b = <-c.rCh: + case <-readDeadlineCh: + c.readDeadlineChLock.Lock() + c.readDeadlineCh = closedDeadlineCh + c.readDeadlineChLock.Unlock() + // rCh may contain data when deadline is reached. + // Read the data before returning ErrTimeout. + select { + case c.b = <-c.rCh: + default: + return ErrTimeout + } case <-c.pc.stopCh: - return io.EOF + // rCh may contain data when stopCh is closed. + // Read the data before returning EOF. + select { + case c.b = <-c.rCh: + default: + return io.EOF + } } } @@ -161,7 +197,26 @@ var ( errWouldBlock = errors.New("would block") errConnectionClosed = errors.New("connection closed") - errNoDeadlines = errors.New("deadline not supported") +) + +type timeoutError struct { +} + +func (e *timeoutError) Error() string { + return "timeout" +} + +// Only implement the Timeout() function of the net.Error interface. 
+// This allows for checks like: +// +// if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() { +func (e *timeoutError) Timeout() bool { + return true +} + +var ( + // ErrTimeout is returned from Read() or Write() on timeout. + ErrTimeout = &timeoutError{} ) func (c *pipeConn) Close() error { @@ -176,18 +231,55 @@ return pipeAddr(0) } -func (c *pipeConn) SetDeadline(t time.Time) error { - return errNoDeadlines +func (c *pipeConn) SetDeadline(deadline time.Time) error { + c.SetReadDeadline(deadline) //nolint:errcheck + c.SetWriteDeadline(deadline) //nolint:errcheck + return nil } -func (c *pipeConn) SetReadDeadline(t time.Time) error { - return c.SetDeadline(t) +func (c *pipeConn) SetReadDeadline(deadline time.Time) error { + if c.readDeadlineTimer == nil { + c.readDeadlineTimer = time.NewTimer(time.Hour) + } + readDeadlineCh := updateTimer(c.readDeadlineTimer, deadline) + c.readDeadlineChLock.Lock() + c.readDeadlineCh = readDeadlineCh + c.readDeadlineChLock.Unlock() + return nil } -func (c *pipeConn) SetWriteDeadline(t time.Time) error { - return c.SetDeadline(t) +func (c *pipeConn) SetWriteDeadline(deadline time.Time) error { + if c.writeDeadlineTimer == nil { + c.writeDeadlineTimer = time.NewTimer(time.Hour) + } + c.writeDeadlineCh = updateTimer(c.writeDeadlineTimer, deadline) + return nil } +func updateTimer(t *time.Timer, deadline time.Time) <-chan time.Time { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + if deadline.IsZero() { + return nil + } + d := -time.Since(deadline) + if d <= 0 { + return closedDeadlineCh + } + t.Reset(d) + return t.C +} + +var closedDeadlineCh = func() <-chan time.Time { + ch := make(chan time.Time) + close(ch) + return ch +}() + type pipeAddr int func (pipeAddr) Network() string { diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns_test.go golang-github-valyala-fasthttp-1.31.0/fasthttputil/pipeconns_test.go --- golang-github-valyala-fasthttp-20160617/fasthttputil/pipeconns_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/pipeconns_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,6 +1,7 @@ package fasthttputil import ( + "bytes" "fmt" "io" "io/ioutil" @@ -9,7 +10,120 @@ "time" ) +func TestPipeConnsWriteTimeout(t *testing.T) { + t.Parallel() + + pc := NewPipeConns() + c1 := pc.Conn1() + + deadline := time.Now().Add(time.Millisecond) + if err := c1.SetWriteDeadline(deadline); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + data := []byte("foobar") + for { + _, err := c1.Write(data) + if err != nil { + if err == ErrTimeout { + break + } + t.Fatalf("unexpected error: %s", err) + } + } + + for i := 0; i < 10; i++ { + _, err := c1.Write(data) + if err == nil { + t.Fatalf("expecting error") + } + if err != ErrTimeout { + t.Fatalf("unexpected error: %s. Expecting %s", err, ErrTimeout) + } + } + + // read the written data + c2 := pc.Conn2() + if err := c2.SetReadDeadline(time.Now().Add(10 * time.Millisecond)); err != nil { + t.Fatalf("unexpected error: %s", err) + } + for { + _, err := c2.Read(data) + if err != nil { + if err == ErrTimeout { + break + } + t.Fatalf("unexpected error: %s", err) + } + } + + for i := 0; i < 10; i++ { + _, err := c2.Read(data) + if err == nil { + t.Fatalf("expecting error") + } + if err != ErrTimeout { + t.Fatalf("unexpected error: %s. 
Expecting %s", err, ErrTimeout) + } + } +} + +func TestPipeConnsPositiveReadTimeout(t *testing.T) { + t.Parallel() + + testPipeConnsReadTimeout(t, time.Millisecond) +} + +func TestPipeConnsNegativeReadTimeout(t *testing.T) { + t.Parallel() + + testPipeConnsReadTimeout(t, -time.Second) +} + +var zeroTime time.Time + +func testPipeConnsReadTimeout(t *testing.T, timeout time.Duration) { + pc := NewPipeConns() + c1 := pc.Conn1() + + deadline := time.Now().Add(timeout) + if err := c1.SetReadDeadline(deadline); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var buf [1]byte + for i := 0; i < 10; i++ { + _, err := c1.Read(buf[:]) + if err == nil { + t.Fatalf("expecting error on iteration %d", i) + } + if err != ErrTimeout { + t.Fatalf("unexpected error on iteration %d: %s. Expecting %s", i, err, ErrTimeout) + } + } + + // disable deadline and send data from c2 to c1 + if err := c1.SetReadDeadline(zeroTime); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + data := []byte("foobar") + c2 := pc.Conn2() + if _, err := c2.Write(data); err != nil { + t.Fatalf("unexpected error: %s", err) + } + dataBuf := make([]byte, len(data)) + if _, err := io.ReadFull(c1, dataBuf); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !bytes.Equal(data, dataBuf) { + t.Fatalf("unexpected data received: %q. Expecting %q", dataBuf, data) + } +} + func TestPipeConnsCloseWhileReadWriteConcurrent(t *testing.T) { + t.Parallel() + concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { @@ -22,13 +136,15 @@ for i := 0; i < concurrency; i++ { select { case <-ch: - case <-time.After(3 * time.Second): + case <-time.After(5 * time.Second): t.Fatalf("timeout") } } } func TestPipeConnsCloseWhileReadWriteSerial(t *testing.T) { + t.Parallel() + testPipeConnsCloseWhileReadWriteSerial(t) } @@ -99,10 +215,14 @@ } func TestPipeConnsReadWriteSerial(t *testing.T) { + t.Parallel() + testPipeConnsReadWriteSerial(t) } func TestPipeConnsReadWriteConcurrent(t *testing.T) { + t.Parallel() + testConcurrency(t, 10, testPipeConnsReadWriteSerial) } @@ -156,10 +276,14 @@ } func TestPipeConnsCloseSerial(t *testing.T) { + t.Parallel() + testPipeConnsCloseSerial(t) } func TestPipeConnsCloseConcurrent(t *testing.T) { + t.Parallel() + testConcurrency(t, 10, testPipeConnsCloseSerial) } diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/rsa.key golang-github-valyala-fasthttp-1.31.0/fasthttputil/rsa.key --- golang-github-valyala-fasthttp-20160617/fasthttputil/rsa.key 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/rsa.key 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG +3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U +wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 +FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf +IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg +GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF +sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 +sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D +uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb +K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 +YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ +DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk 
+B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV +Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x +IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY +wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj +wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D +FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m +tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX +fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU +ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk +K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT +6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt +9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN +Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV +c257YgaWmjK9uB0Y2r2VxS0G +-----END PRIVATE KEY----- diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/rsa.pem golang-github-valyala-fasthttp-1.31.0/fasthttputil/rsa.pem --- golang-github-valyala-fasthttp-20160617/fasthttputil/rsa.pem 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/rsa.pem 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV +BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV +MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB +CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D +K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te ++z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij +L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 +xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY +6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG +SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 +L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 +45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li +K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 +X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI +whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd +-----END CERTIFICATE----- diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.key golang-github-valyala-fasthttp-1.31.0/fasthttputil/ssl-cert-snakeoil.key --- golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.key 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/ssl-cert-snakeoil.key 1970-01-01 00:00:00.000000000 +0000 @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG -3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U -wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 -FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf -IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg -GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF -sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 -sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D -uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb -K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 -YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ -DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk 
-B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV -Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x -IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY -wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj -wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D -FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m -tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX -fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU -ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk -K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT -6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt -9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN -Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV -c257YgaWmjK9uB0Y2r2VxS0G ------END PRIVATE KEY----- diff -Nru golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.pem golang-github-valyala-fasthttp-1.31.0/fasthttputil/ssl-cert-snakeoil.pem --- golang-github-valyala-fasthttp-20160617/fasthttputil/ssl-cert-snakeoil.pem 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fasthttputil/ssl-cert-snakeoil.pem 1970-01-01 00:00:00.000000000 +0000 @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV -BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV -MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D -K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te -+z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij -L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 -xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY -6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG -SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 -L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 -45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li -K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 -X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI -whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd ------END CERTIFICATE----- diff -Nru golang-github-valyala-fasthttp-20160617/fs.go golang-github-valyala-fasthttp-1.31.0/fs.go --- golang-github-valyala-fasthttp-20160617/fs.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fs.go 2021-10-09 18:39:05.000000000 +0000 @@ -16,7 +16,9 @@ "sync" "time" + "github.com/andybalholm/brotli" "github.com/klauspost/compress/gzip" + "github.com/valyala/bytebufferpool" ) // ServeFileBytesUncompressed returns HTTP response containing file contents @@ -83,12 +85,16 @@ }) if len(path) == 0 || path[0] != '/' { // extend relative path to absolute path + hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/' var err error if path, err = filepath.Abs(path); err != nil { ctx.Logger().Printf("cannot resolve path %q to absolute file path: %s", path, err) ctx.Error("Internal Server Error", StatusInternalServerError) return } + if hasTrailingSlash { + path += "/" + } } ctx.Request.SetRequestURI(path) rootFSHandler(ctx) @@ -100,6 +106,7 @@ Root: "/", GenerateIndexPages: true, Compress: true, + CompressBrotli: true, AcceptByteRange: true, } rootFSHandler RequestHandler @@ -139,12 +146,12 @@ if len(host) == 0 { host = strInvalidHost } - b := AcquireByteBuffer() + b := 
bytebufferpool.Get() b.B = append(b.B, '/') b.B = append(b.B, host...) b.B = append(b.B, path...) ctx.URI().SetPathBytes(b.B) - ReleaseByteBuffer(b) + bytebufferpool.Put(b) return ctx.Path() } @@ -193,7 +200,7 @@ // // It is prohibited copying FS values. Create new values instead. type FS struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Path to the root directory to serve files from. Root string @@ -225,12 +232,19 @@ // It adds CompressedFileSuffix suffix to the original file name and // tries saving the resulting compressed file under the new file name. // So it is advisable to give the server write access to Root - // and to all inner folders in order to minimze CPU usage when serving + // and to all inner folders in order to minimize CPU usage when serving // compressed responses. // // Transparent compression is disabled by default. Compress bool + // Uses brotli encoding and fallbacks to gzip in responses if set to true, uses gzip if set to false. + // + // This value has sense only if Compress is set. + // + // Brotli encoding is disabled by default. + CompressBrotli bool + // Enables byte range requests if set to true. // // Byte range requests are disabled by default. @@ -241,6 +255,14 @@ // By default request path is not modified. PathRewrite PathRewriteFunc + // PathNotFound fires when file is not found in filesystem + // this functions tries to replace "Cannot open requested path" + // server response giving to the programmer the control of server flow. + // + // By default PathNotFound returns + // "Cannot open requested path" + PathNotFound RequestHandler + // Expiration duration for inactive file handlers. // // FSHandlerCacheDuration is used by default. @@ -253,6 +275,18 @@ // FSCompressedFileSuffix is used by default. CompressedFileSuffix string + // Suffixes list to add to compressedFileSuffix depending on encoding + // + // This value has sense only if Compress is set. + // + // FSCompressedFileSuffixes is used by default. + CompressedFileSuffixes map[string]string + + // If CleanStop is set, the channel can be closed to stop the cleanup handlers + // for the FS RequestHandlers created with NewRequestHandler. + // NEVER close this channel while the handler is still being used! + CleanStop chan struct{} + once sync.Once h RequestHandler } @@ -262,6 +296,14 @@ // See FS.Compress for details. const FSCompressedFileSuffix = ".fasthttp.gz" +// FSCompressedFileSuffixes is the suffixes FS adds to the original file names depending on encoding +// when trying to store compressed file under the new file name. +// See FS.Compress for details. +var FSCompressedFileSuffixes = map[string]string{ + "gzip": ".fasthttp.gz", + "br": ".fasthttp.br", +} + // FSHandlerCacheDuration is the default expiration duration for inactive // file handlers opened by FS. 
const FSHandlerCacheDuration = 10 * time.Second @@ -332,29 +374,59 @@ if cacheDuration <= 0 { cacheDuration = FSHandlerCacheDuration } - compressedFileSuffix := fs.CompressedFileSuffix - if len(compressedFileSuffix) == 0 { - compressedFileSuffix = FSCompressedFileSuffix + + compressedFileSuffixes := fs.CompressedFileSuffixes + if len(compressedFileSuffixes["br"]) == 0 || len(compressedFileSuffixes["gzip"]) == 0 || + compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] { + compressedFileSuffixes = FSCompressedFileSuffixes + } + + if len(fs.CompressedFileSuffix) > 0 { + compressedFileSuffixes["gzip"] = fs.CompressedFileSuffix + compressedFileSuffixes["br"] = FSCompressedFileSuffixes["br"] } h := &fsHandler{ - root: root, - indexNames: fs.IndexNames, - pathRewrite: fs.PathRewrite, - generateIndexPages: fs.GenerateIndexPages, - compress: fs.Compress, - acceptByteRange: fs.AcceptByteRange, - cacheDuration: cacheDuration, - compressedFileSuffix: compressedFileSuffix, - cache: make(map[string]*fsFile), - compressedCache: make(map[string]*fsFile), + root: root, + indexNames: fs.IndexNames, + pathRewrite: fs.PathRewrite, + generateIndexPages: fs.GenerateIndexPages, + compress: fs.Compress, + compressBrotli: fs.CompressBrotli, + pathNotFound: fs.PathNotFound, + acceptByteRange: fs.AcceptByteRange, + cacheDuration: cacheDuration, + compressedFileSuffixes: compressedFileSuffixes, + cache: make(map[string]*fsFile), + cacheBrotli: make(map[string]*fsFile), + cacheGzip: make(map[string]*fsFile), } go func() { var pendingFiles []*fsFile + + clean := func() { + pendingFiles = h.cleanCache(pendingFiles) + } + + if fs.CleanStop != nil { + t := time.NewTicker(cacheDuration / 2) + for { + select { + case <-t.C: + clean() + case _, stillOpen := <-fs.CleanStop: + // Ignore values send on the channel, only stop when it is closed. 
+ if !stillOpen { + t.Stop() + return + } + } + } + } for { time.Sleep(cacheDuration / 2) - pendingFiles = h.cleanCache(pendingFiles) + clean() } }() @@ -362,18 +434,21 @@ } type fsHandler struct { - root string - indexNames []string - pathRewrite PathRewriteFunc - generateIndexPages bool - compress bool - acceptByteRange bool - cacheDuration time.Duration - compressedFileSuffix string - - cache map[string]*fsFile - compressedCache map[string]*fsFile - cacheLock sync.Mutex + root string + indexNames []string + pathRewrite PathRewriteFunc + pathNotFound RequestHandler + generateIndexPages bool + compress bool + compressBrotli bool + acceptByteRange bool + cacheDuration time.Duration + compressedFileSuffixes map[string]string + + cache map[string]*fsFile + cacheBrotli map[string]*fsFile + cacheGzip map[string]*fsFile + cacheLock sync.Mutex smallFileReaderPool sync.Pool } @@ -593,7 +668,7 @@ curPos := r.startPos bufv := copyBufPool.Get() buf := bufv.([]byte) - for err != nil { + for err == nil { tailLen := r.endPos - curPos if tailLen <= 0 { break @@ -637,7 +712,8 @@ pendingFiles = remainingFiles pendingFiles, filesToRelease = cleanCacheNolock(h.cache, pendingFiles, filesToRelease, h.cacheDuration) - pendingFiles, filesToRelease = cleanCacheNolock(h.compressedCache, pendingFiles, filesToRelease, h.cacheDuration) + pendingFiles, filesToRelease = cleanCacheNolock(h.cacheBrotli, pendingFiles, filesToRelease, h.cacheDuration) + pendingFiles, filesToRelease = cleanCacheNolock(h.cacheGzip, pendingFiles, filesToRelease, h.cacheDuration) h.cacheLock.Unlock() @@ -673,6 +749,7 @@ } else { path = ctx.Path() } + hasTrailingSlash := len(path) > 0 && path[len(path)-1] == '/' path = stripTrailingSlashes(path) if n := bytes.IndexByte(path, 0); n >= 0 { @@ -693,10 +770,18 @@ mustCompress := false fileCache := h.cache + fileEncoding := "" byteRange := ctx.Request.Header.peek(strRange) - if len(byteRange) == 0 && h.compress && ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { - mustCompress = true - fileCache = h.compressedCache + if len(byteRange) == 0 && h.compress { + if h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr) { + mustCompress = true + fileCache = h.cacheBrotli + fileEncoding = "br" + } else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { + mustCompress = true + fileCache = h.cacheGzip + fileEncoding = "gzip" + } } h.cacheLock.Lock() @@ -710,15 +795,19 @@ pathStr := string(path) filePath := h.root + pathStr var err error - ff, err = h.openFSFile(filePath, mustCompress) + ff, err = h.openFSFile(filePath, mustCompress, fileEncoding) if mustCompress && err == errNoCreatePermission { ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. 
"+ "Allow write access to the directory with this file in order to improve fasthttp performance", filePath) mustCompress = false - ff, err = h.openFSFile(filePath, mustCompress) + ff, err = h.openFSFile(filePath, mustCompress, fileEncoding) } if err == errDirIndexRequired { - ff, err = h.openIndexFile(ctx, filePath, mustCompress) + if !hasTrailingSlash { + ctx.RedirectBytes(append(path, '/'), StatusFound) + return + } + ff, err = h.openIndexFile(ctx, filePath, mustCompress, fileEncoding) if err != nil { ctx.Logger().Printf("cannot open dir index %q: %s", filePath, err) ctx.Error("Directory index is forbidden", StatusForbidden) @@ -726,7 +815,12 @@ } } else if err != nil { ctx.Logger().Printf("cannot open file %q: %s", filePath, err) - ctx.Error("Cannot open requested path", StatusNotFound) + if h.pathNotFound == nil { + ctx.Error("Cannot open requested path", StatusNotFound) + } else { + ctx.SetStatusCode(StatusNotFound) + h.pathNotFound(ctx) + } return } @@ -764,7 +858,11 @@ hdr := &ctx.Response.Header if ff.compressed { - hdr.SetCanonical(strContentEncoding, strGzip) + if fileEncoding == "br" { + hdr.SetCanonical(strContentEncoding, strBr) + } else if fileEncoding == "gzip" { + hdr.SetCanonical(strContentEncoding, strGzip) + } } statusCode := StatusOK @@ -808,7 +906,10 @@ } } } - ctx.SetContentType(ff.contentType) + hdr.noDefaultContentType = true + if len(hdr.ContentType()) == 0 { + ctx.SetContentType(ff.contentType) + } ctx.SetStatusCode(statusCode) } @@ -872,10 +973,10 @@ return startPos, endPos, nil } -func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress bool) (*fsFile, error) { +func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) { for _, indexName := range h.indexNames { indexFilePath := dirPath + "/" + indexName - ff, err := h.openFSFile(indexFilePath, mustCompress) + ff, err := h.openFSFile(indexFilePath, mustCompress, fileEncoding) if err == nil { return ff, nil } @@ -888,7 +989,7 @@ return nil, fmt.Errorf("cannot access directory without index page. Directory %q", dirPath) } - return h.createDirIndex(ctx.URI(), dirPath, mustCompress) + return h.createDirIndex(ctx.URI(), dirPath, mustCompress, fileEncoding) } var ( @@ -896,8 +997,8 @@ errNoCreatePermission = errors.New("no 'create file' permissions") ) -func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool) (*fsFile, error) { - w := &ByteBuffer{} +func (h *fsHandler) createDirIndex(base *URI, dirPath string, mustCompress bool, fileEncoding string) (*fsFile, error) { + w := &bytebufferpool.ByteBuffer{} basePathEscaped := html.EscapeString(string(base.Path())) fmt.Fprintf(w, "%s", basePathEscaped) @@ -924,12 +1025,15 @@ } fm := make(map[string]os.FileInfo, len(fileinfos)) - var filenames []string + filenames := make([]string, 0, len(fileinfos)) +nestedContinue: for _, fi := range fileinfos { name := fi.Name() - if strings.HasSuffix(name, h.compressedFileSuffix) { - // Do not show compressed files on index page. - continue + for _, cfs := range h.compressedFileSuffixes { + if strings.HasSuffix(name, cfs) { + // Do not show compressed files on index page. 
+ continue nestedContinue + } } fm[name] = fi filenames = append(filenames, name) @@ -939,7 +1043,7 @@ base.CopyTo(&u) u.Update(string(u.Path()) + "/") - sort.Sort(sort.StringSlice(filenames)) + sort.Strings(filenames) for _, name := range filenames { u.Update(name) pathEscaped := html.EscapeString(string(u.Path())) @@ -957,12 +1061,11 @@ fmt.Fprintf(w, "") if mustCompress { - var zbuf ByteBuffer - zw := acquireGzipWriter(&zbuf, CompressDefaultCompression) - _, err = zw.Write(w.B) - releaseGzipWriter(zw) - if err != nil { - return nil, fmt.Errorf("error when compressing automatically generated index for directory %q: %s", dirPath, err) + var zbuf bytebufferpool.ByteBuffer + if fileEncoding == "br" { + zbuf.B = AppendBrotliBytesLevel(zbuf.B, w.B, CompressDefaultCompression) + } else if fileEncoding == "gzip" { + zbuf.B = AppendGzipBytesLevel(zbuf.B, w.B, CompressDefaultCompression) } w = &zbuf } @@ -988,7 +1091,7 @@ fsMaxCompressibleFileSize = 8 * 1024 * 1024 ) -func (h *fsHandler) compressAndOpenFSFile(filePath string) (*fsFile, error) { +func (h *fsHandler) compressAndOpenFSFile(filePath string, fileEncoding string) (*fsFile, error) { f, err := os.Open(filePath) if err != nil { return nil, err @@ -1005,13 +1108,13 @@ return nil, errDirIndexRequired } - if strings.HasSuffix(filePath, h.compressedFileSuffix) || + if strings.HasSuffix(filePath, h.compressedFileSuffixes[fileEncoding]) || fileInfo.Size() > fsMaxCompressibleFileSize || !isFileCompressible(f, fsMinCompressRatio) { - return h.newFSFile(f, fileInfo, false) + return h.newFSFile(f, fileInfo, false, "") } - compressedFilePath := filePath + h.compressedFileSuffix + compressedFilePath := filePath + h.compressedFileSuffixes[fileEncoding] absPath, err := filepath.Abs(compressedFilePath) if err != nil { f.Close() @@ -1020,20 +1123,20 @@ flock := getFileLock(absPath) flock.Lock() - ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath) + ff, err := h.compressFileNolock(f, fileInfo, filePath, compressedFilePath, fileEncoding) flock.Unlock() return ff, err } -func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string) (*fsFile, error) { +func (h *fsHandler) compressFileNolock(f *os.File, fileInfo os.FileInfo, filePath, compressedFilePath string, fileEncoding string) (*fsFile, error) { // Attempt to open compressed file created by another concurrent // goroutine. // It is safe opening such a file, since the file creation // is guarded by file mutex - see getFileLock call. 
if _, err := os.Stat(compressedFilePath); err == nil { f.Close() - return h.newCompressedFSFile(compressedFilePath) + return h.newCompressedFSFile(compressedFilePath, fileEncoding) } // Create temporary file, so concurrent goroutines don't use @@ -1047,13 +1150,21 @@ } return nil, errNoCreatePermission } - - zw := acquireGzipWriter(zf, CompressDefaultCompression) - _, err = copyZeroAlloc(zw, f) - if err1 := zw.Flush(); err == nil { - err = err1 + if fileEncoding == "br" { + zw := acquireStacklessBrotliWriter(zf, CompressDefaultCompression) + _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessBrotliWriter(zw, CompressDefaultCompression) + } else if fileEncoding == "gzip" { + zw := acquireStacklessGzipWriter(zf, CompressDefaultCompression) + _, err = copyZeroAlloc(zw, f) + if err1 := zw.Flush(); err == nil { + err = err1 + } + releaseStacklessGzipWriter(zw, CompressDefaultCompression) } - releaseGzipWriter(zw) zf.Close() f.Close() if err != nil { @@ -1066,10 +1177,10 @@ if err = os.Rename(tmpFilePath, compressedFilePath); err != nil { return nil, fmt.Errorf("cannot move compressed file from %q to %q: %s", tmpFilePath, compressedFilePath, err) } - return h.newCompressedFSFile(compressedFilePath) + return h.newCompressedFSFile(compressedFilePath, fileEncoding) } -func (h *fsHandler) newCompressedFSFile(filePath string) (*fsFile, error) { +func (h *fsHandler) newCompressedFSFile(filePath string, fileEncoding string) (*fsFile, error) { f, err := os.Open(filePath) if err != nil { return nil, fmt.Errorf("cannot open compressed file %q: %s", filePath, err) @@ -1079,19 +1190,19 @@ f.Close() return nil, fmt.Errorf("cannot obtain info for compressed file %q: %s", filePath, err) } - return h.newFSFile(f, fileInfo, true) + return h.newFSFile(f, fileInfo, true, fileEncoding) } -func (h *fsHandler) openFSFile(filePath string, mustCompress bool) (*fsFile, error) { +func (h *fsHandler) openFSFile(filePath string, mustCompress bool, fileEncoding string) (*fsFile, error) { filePathOriginal := filePath if mustCompress { - filePath += h.compressedFileSuffix + filePath += h.compressedFileSuffixes[fileEncoding] } f, err := os.Open(filePath) if err != nil { if mustCompress && os.IsNotExist(err) { - return h.compressAndOpenFSFile(filePathOriginal) + return h.compressAndOpenFSFile(filePathOriginal, fileEncoding) } return nil, err } @@ -1106,7 +1217,7 @@ f.Close() if mustCompress { return nil, fmt.Errorf("directory with unexpected suffix found: %q. Suffix: %q", - filePath, h.compressedFileSuffix) + filePath, h.compressedFileSuffixes[fileEncoding]) } return nil, errDirIndexRequired } @@ -1118,18 +1229,21 @@ return nil, fmt.Errorf("cannot obtain info for original file %q: %s", filePathOriginal, err) } - if fileInfoOriginal.ModTime() != fileInfo.ModTime() { + // Only re-create the compressed file if there was more than a second between the mod times. + // On MacOS the gzip seems to truncate the nanoseconds in the mod time causing the original file + // to look newer than the gzipped file. + if fileInfoOriginal.ModTime().Sub(fileInfo.ModTime()) >= time.Second { // The compressed file became stale. Re-create it. 
f.Close() os.Remove(filePath) - return h.compressAndOpenFSFile(filePathOriginal) + return h.compressAndOpenFSFile(filePathOriginal, fileEncoding) } } - return h.newFSFile(f, fileInfo, mustCompress) + return h.newFSFile(f, fileInfo, mustCompress, fileEncoding) } -func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool) (*fsFile, error) { +func (h *fsHandler) newFSFile(f *os.File, fileInfo os.FileInfo, compressed bool, fileEncoding string) (*fsFile, error) { n := fileInfo.Size() contentLength := int(n) if n != int64(contentLength) { @@ -1138,10 +1252,10 @@ } // detect content-type - ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffix) + ext := fileExtension(fileInfo.Name(), compressed, h.compressedFileSuffixes[fileEncoding]) contentType := mime.TypeByExtension(ext) if len(contentType) == 0 { - data, err := readFileHeader(f, compressed) + data, err := readFileHeader(f, compressed, fileEncoding) if err != nil { return nil, fmt.Errorf("cannot read header of the file %q: %s", f.Name(), err) } @@ -1163,15 +1277,25 @@ return ff, nil } -func readFileHeader(f *os.File, compressed bool) ([]byte, error) { +func readFileHeader(f *os.File, compressed bool, fileEncoding string) ([]byte, error) { r := io.Reader(f) - var zr *gzip.Reader + var ( + br *brotli.Reader + zr *gzip.Reader + ) if compressed { var err error - if zr, err = acquireGzipReader(f); err != nil { - return nil, err + if fileEncoding == "br" { + if br, err = acquireBrotliReader(f); err != nil { + return nil, err + } + r = br + } else if fileEncoding == "gzip" { + if zr, err = acquireGzipReader(f); err != nil { + return nil, err + } + r = zr } - r = zr } lr := &io.LimitedReader{ @@ -1179,7 +1303,13 @@ N: 512, } data, err := ioutil.ReadAll(lr) - f.Seek(0, 0) + if _, err := f.Seek(0, 0); err != nil { + return nil, err + } + + if br != nil { + releaseBrotliReader(br) + } if zr != nil { releaseGzipReader(zr) diff -Nru golang-github-valyala-fasthttp-20160617/fs_test.go golang-github-valyala-fasthttp-1.31.0/fs_test.go --- golang-github-valyala-fasthttp-20160617/fs_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fs_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -4,15 +4,28 @@ "bufio" "bytes" "fmt" + "io" "io/ioutil" "math/rand" "os" + "path" + "runtime" "sort" "testing" "time" ) +type TestLogger struct { + t *testing.T +} + +func (t TestLogger) Printf(format string, args ...interface{}) { + t.t.Logf(format, args...) +} + func TestNewVHostPathRewriter(t *testing.T) { + t.Parallel() + var ctx RequestCtx var req Request req.Header.SetHost("foobar.com") @@ -37,6 +50,8 @@ } func TestNewVHostPathRewriterMaliciousHost(t *testing.T) { + t.Parallel() + var ctx RequestCtx var req Request req.Header.SetHost("/../../../etc/passwd") @@ -45,16 +60,63 @@ f := NewVHostPathRewriter(0) path := f(&ctx) - expectedPath := "/invalid-host/foo/bar/baz" + expectedPath := "/invalid-host/" if string(path) != expectedPath { t.Fatalf("unexpected path %q. Expecting %q", path, expectedPath) } } +func testPathNotFound(t *testing.T, pathNotFoundFunc RequestHandler) { + var ctx RequestCtx + var req Request + req.SetRequestURI("http//some.url/file") + ctx.Init(&req, nil, TestLogger{t}) + + stop := make(chan struct{}) + defer close(stop) + + fs := &FS{ + Root: "./", + PathNotFound: pathNotFoundFunc, + CleanStop: stop, + } + fs.NewRequestHandler()(&ctx) + + if pathNotFoundFunc == nil { + // different to ... 
+ if !bytes.Equal(ctx.Response.Body(), + []byte("Cannot open requested path")) { + t.Fatalf("response defers. Response: %q", ctx.Response.Body()) + } + } else { + // Equals to ... + if bytes.Equal(ctx.Response.Body(), + []byte("Cannot open requested path")) { + t.Fatalf("response defers. Response: %q", ctx.Response.Body()) + } + } +} + +func TestPathNotFound(t *testing.T) { + t.Parallel() + + testPathNotFound(t, nil) +} + +func TestPathNotFoundFunc(t *testing.T) { + t.Parallel() + + testPathNotFound(t, func(ctx *RequestCtx) { + ctx.WriteString("Not found hehe") //nolint:errcheck + }) +} + func TestServeFileHead(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + var ctx RequestCtx var req Request - req.Header.SetMethod("HEAD") + req.Header.SetMethod(MethodHead) req.SetRequestURI("http://foobar.com/baz") ctx.Init(&req, nil, nil) @@ -68,7 +130,7 @@ t.Fatalf("unexpected error: %s", err) } - ce := resp.Header.Peek("Content-Encoding") + ce := resp.Header.Peek(HeaderContentEncoding) if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } @@ -88,23 +150,79 @@ } } -func TestServeFileCompressed(t *testing.T) { +func TestServeFileSmallNoReadFrom(t *testing.T) { + t.Parallel() + + teststr := "hello, world!" + + tempdir, err := ioutil.TempDir("", "httpexpect") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tempdir) + + if err := ioutil.WriteFile( + path.Join(tempdir, "hello"), []byte(teststr), 0666); err != nil { + t.Fatal(err) + } + var ctx RequestCtx var req Request req.SetRequestURI("http://foobar.com/baz") - req.Header.Set("Accept-Encoding", "gzip") ctx.Init(&req, nil, nil) - ServeFile(&ctx, "fs.go") + ServeFile(&ctx, path.Join(tempdir, "hello")) + + reader, ok := ctx.Response.bodyStream.(*fsSmallFileReader) + if !ok { + t.Fatal("expected fsSmallFileReader") + } + + buf := bytes.NewBuffer(nil) + + n, err := reader.WriteTo(pureWriter{buf}) + if err != nil { + t.Fatal(err) + } + + if n != int64(len(teststr)) { + t.Fatalf("expected %d bytes, got %d bytes", len(teststr), n) + } + + body := buf.String() + if body != teststr { + t.Fatalf("expected '%s'", teststr) + } +} + +type pureWriter struct { + w io.Writer +} + +func (pw pureWriter) Write(p []byte) (nn int, err error) { + return pw.w.Write(p) +} + +func TestServeFileCompressed(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + + var ctx RequestCtx + ctx.Init(&Request{}, nil, nil) var resp Response + + // request compressed gzip file + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip") + ServeFile(&ctx, "fs.go") + s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } - ce := resp.Header.Peek("Content-Encoding") + ce := resp.Header.Peek(HeaderContentEncoding) if string(ce) != "gzip" { t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "gzip") } @@ -120,13 +238,44 @@ if !bytes.Equal(body, expectedBody) { t.Fatalf("unexpected body %q. 
expecting %q", body, expectedBody) } + + // request compressed brotli file + ctx.Request.Reset() + ctx.Request.SetRequestURI("http://foobar.com/baz") + ctx.Request.Header.Set(HeaderAcceptEncoding, "br") + ServeFile(&ctx, "fs.go") + + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err = resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "br" { + t.Fatalf("Unexpected 'Content-Encoding' %q. Expecting %q", ce, "br") + } + + body, err = resp.BodyUnbrotli() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + expectedBody, err = getFileContents("/fs.go") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !bytes.Equal(body, expectedBody) { + t.Fatalf("unexpected body %q. expecting %q", body, expectedBody) + } } func TestServeFileUncompressed(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + var ctx RequestCtx var req Request req.SetRequestURI("http://foobar.com/baz") - req.Header.Set("Accept-Encoding", "gzip") + req.Header.Set(HeaderAcceptEncoding, "gzip") ctx.Init(&req, nil, nil) ServeFileUncompressed(&ctx, "fs.go") @@ -138,7 +287,7 @@ t.Fatalf("unexpected error: %s", err) } - ce := resp.Header.Peek("Content-Encoding") + ce := resp.Header.Peek(HeaderContentEncoding) if len(ce) > 0 { t.Fatalf("Unexpected 'Content-Encoding' %q", ce) } @@ -154,9 +303,15 @@ } func TestFSByteRangeConcurrent(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + + stop := make(chan struct{}) + defer close(stop) + fs := &FS{ Root: ".", AcceptByteRange: true, + CleanStop: stop, } h := fs.NewRequestHandler() @@ -182,9 +337,15 @@ } func TestFSByteRangeSingleThread(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + + stop := make(chan struct{}) + defer close(stop) + fs := &FS{ Root: ".", AcceptByteRange: true, + CleanStop: stop, } h := fs.NewRequestHandler() @@ -221,7 +382,7 @@ if resp.StatusCode() != StatusPartialContent { t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusPartialContent, filePath) } - cr := resp.Header.Peek("Content-Range") + cr := resp.Header.Peek(HeaderContentRange) expectedCR := fmt.Sprintf("bytes %d-%d/%d", startPos, endPos, fileSize) if string(cr) != expectedCR { @@ -252,6 +413,8 @@ } func TestParseByteRangeSuccess(t *testing.T) { + t.Parallel() + testParseByteRangeSuccess(t, "bytes=0-0", 1, 0, 0) testParseByteRangeSuccess(t, "bytes=1234-6789", 6790, 1234, 6789) @@ -283,6 +446,8 @@ } func TestParseByteRangeError(t *testing.T) { + t.Parallel() + // invalid value testParseByteRangeError(t, "asdfasdfas", 1234) @@ -315,10 +480,17 @@ } func TestFSCompressConcurrent(t *testing.T) { + // This test can't run parallel as files in / might be changed by other tests. + + stop := make(chan struct{}) + defer close(stop) + fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, + CompressBrotli: true, + CleanStop: stop, } h := fs.NewRequestHandler() @@ -338,17 +510,24 @@ for i := 0; i < concurrency; i++ { select { case <-ch: - case <-time.After(time.Second): + case <-time.After(time.Second * 3): t.Fatalf("timeout") } } } func TestFSCompressSingleThread(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. 
+ + stop := make(chan struct{}) + defer close(stop) + fs := &FS{ Root: ".", GenerateIndexPages: true, Compress: true, + CompressBrotli: true, + CleanStop: stop, } h := fs.NewRequestHandler() @@ -361,69 +540,80 @@ var ctx RequestCtx ctx.Init(&Request{}, nil, nil) + var resp Response + // request uncompressed file ctx.Request.Reset() ctx.Request.SetRequestURI(filePath) h(&ctx) - - var resp Response s := ctx.Response.String() br := bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) + t.Errorf("unexpected error: %s. filePath=%q", err, filePath) } if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) + t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } - ce := resp.Header.Peek("Content-Encoding") + ce := resp.Header.Peek(HeaderContentEncoding) if string(ce) != "" { - t.Fatalf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) + t.Errorf("unexpected content-encoding %q. Expecting empty string. filePath=%q", ce, filePath) } body := string(resp.Body()) - // request compressed file + // request compressed gzip file ctx.Request.Reset() ctx.Request.SetRequestURI(filePath) - ctx.Request.Header.Set("Accept-Encoding", "gzip") + ctx.Request.Header.Set(HeaderAcceptEncoding, "gzip") h(&ctx) s = ctx.Response.String() br = bufio.NewReader(bytes.NewBufferString(s)) if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s. filePath=%q", err, filePath) + t.Errorf("unexpected error: %s. filePath=%q", err, filePath) } if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) + t.Errorf("unexpected status code: %d. Expecting %d. filePath=%q", resp.StatusCode(), StatusOK, filePath) } - ce = resp.Header.Peek("Content-Encoding") + ce = resp.Header.Peek(HeaderContentEncoding) if string(ce) != "gzip" { - t.Fatalf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) + t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "gzip", filePath) } zbody, err := resp.BodyGunzip() if err != nil { - t.Fatalf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath) + t.Errorf("unexpected error when gunzipping response body: %s. filePath=%q", err, filePath) } if string(zbody) != body { - t.Fatalf("unexpected body %q. Expected %q. FilePath=%q", zbody, body, filePath) + t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath) } -} -func TestFileLock(t *testing.T) { - for i := 0; i < 10; i++ { - filePath := fmt.Sprintf("foo/bar/%d.jpg", i) - lock := getFileLock(filePath) - lock.Lock() - lock.Unlock() + // request compressed brotli file + ctx.Request.Reset() + ctx.Request.SetRequestURI(filePath) + ctx.Request.Header.Set(HeaderAcceptEncoding, "br") + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s. filePath=%q", err, filePath) } - - for i := 0; i < 10; i++ { - filePath := fmt.Sprintf("foo/bar/%d.jpg", i) - lock := getFileLock(filePath) - lock.Lock() - lock.Unlock() + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d. 
filePath=%q", resp.StatusCode(), StatusOK, filePath) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "br" { + t.Errorf("unexpected content-encoding %q. Expecting %q. filePath=%q", ce, "br", filePath) + } + zbody, err = resp.BodyUnbrotli() + if err != nil { + t.Errorf("unexpected error when unbrotling response body: %s. filePath=%q", err, filePath) + } + if string(zbody) != body { + t.Errorf("unexpected body len=%d. Expected len=%d. FilePath=%q", len(zbody), len(body), filePath) } } func TestFSHandlerSingleThread(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + requestHandler := FSHandler(".", 0) f, err := os.Open(".") @@ -436,7 +626,7 @@ if err != nil { t.Fatalf("cannot read dirnames in cwd: %s", err) } - sort.Sort(sort.StringSlice(filenames)) + sort.Strings(filenames) for i := 0; i < 3; i++ { fsHandlerTest(t, requestHandler, filenames) @@ -444,6 +634,8 @@ } func TestFSHandlerConcurrent(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + requestHandler := FSHandler(".", 0) f, err := os.Open(".") @@ -456,7 +648,7 @@ if err != nil { t.Fatalf("cannot read dirnames in cwd: %s", err) } - sort.Sort(sort.StringSlice(filenames)) + sort.Strings(filenames) concurrency := 10 ch := make(chan struct{}, concurrency) @@ -538,6 +730,8 @@ } func TestStripPathSlashes(t *testing.T) { + t.Parallel() + testStripPathSlashes(t, "", 0, "") testStripPathSlashes(t, "", 10, "") testStripPathSlashes(t, "/", 0, "") @@ -565,6 +759,8 @@ } func TestFileExtension(t *testing.T) { + t.Parallel() + testFileExtension(t, "foo.bar", false, "zzz", ".bar") testFileExtension(t, "foobar", false, "zzz", "") testFileExtension(t, "foo.bar.baz", false, "zzz", ".baz") @@ -584,3 +780,61 @@ t.Fatalf("unexpected file extension for file %q: %q. Expecting %q", path, ext, expectedExt) } } + +func TestServeFileContentType(t *testing.T) { + // This test can't run parallel as files in / might by changed by other tests. + + var ctx RequestCtx + var req Request + req.Header.SetMethod(MethodGet) + req.SetRequestURI("http://foobar.com/baz") + ctx.Init(&req, nil, nil) + + ServeFile(&ctx, "testdata/test.png") + + var resp Response + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + expected := []byte("image/png") + if !bytes.Equal(resp.Header.ContentType(), expected) { + t.Fatalf("Unexpected Content-Type, expected: %q got %q", expected, resp.Header.ContentType()) + } +} + +func TestServeFileDirectoryRedirect(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.SkipNow() + } + + var ctx RequestCtx + var req Request + req.SetRequestURI("http://foobar.com") + ctx.Init(&req, nil, nil) + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFile(&ctx, "fasthttputil") + if ctx.Response.StatusCode() != StatusFound { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil' without trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusFound) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFile(&ctx, "fasthttputil/") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for directory '/fasthttputil/' with trailing slash. Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } + + ctx.Request.Reset() + ctx.Response.Reset() + ServeFile(&ctx, "fs.go") + if ctx.Response.StatusCode() != StatusOK { + t.Fatalf("Unexpected status code %d for file '/fs.go'. 
Expecting %d.", ctx.Response.StatusCode(), StatusOK) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/fuzzit/cookie/cookie_fuzz.go golang-github-valyala-fasthttp-1.31.0/fuzzit/cookie/cookie_fuzz.go --- golang-github-valyala-fasthttp-20160617/fuzzit/cookie/cookie_fuzz.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fuzzit/cookie/cookie_fuzz.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,26 @@ +//go:build gofuzz +// +build gofuzz + +package fuzz + +import ( + "bytes" + + "github.com/valyala/fasthttp" +) + +func Fuzz(data []byte) int { + c := fasthttp.AcquireCookie() + defer fasthttp.ReleaseCookie(c) + + if err := c.ParseBytes(data); err != nil { + return 0 + } + + w := bytes.Buffer{} + if _, err := c.WriteTo(&w); err != nil { + return 0 + } + + return 1 +} diff -Nru golang-github-valyala-fasthttp-20160617/fuzzit/request/request_fuzz.go golang-github-valyala-fasthttp-1.31.0/fuzzit/request/request_fuzz.go --- golang-github-valyala-fasthttp-20160617/fuzzit/request/request_fuzz.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fuzzit/request/request_fuzz.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,27 @@ +//go:build gofuzz +// +build gofuzz + +package fuzz + +import ( + "bufio" + "bytes" + + "github.com/valyala/fasthttp" +) + +func Fuzz(data []byte) int { + req := fasthttp.AcquireRequest() + defer fasthttp.ReleaseRequest(req) + + if err := req.ReadLimitBody(bufio.NewReader(bytes.NewReader(data)), 1024*1024); err != nil { + return 0 + } + + w := bytes.Buffer{} + if _, err := req.WriteTo(&w); err != nil { + return 0 + } + + return 1 +} diff -Nru golang-github-valyala-fasthttp-20160617/fuzzit/response/response_fuzz.go golang-github-valyala-fasthttp-1.31.0/fuzzit/response/response_fuzz.go --- golang-github-valyala-fasthttp-20160617/fuzzit/response/response_fuzz.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fuzzit/response/response_fuzz.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,27 @@ +//go:build gofuzz +// +build gofuzz + +package fuzz + +import ( + "bufio" + "bytes" + + "github.com/valyala/fasthttp" +) + +func Fuzz(data []byte) int { + res := fasthttp.AcquireResponse() + defer fasthttp.ReleaseResponse(res) + + if err := res.ReadLimitBody(bufio.NewReader(bytes.NewReader(data)), 1024*1024); err != nil { + return 0 + } + + w := bytes.Buffer{} + if _, err := res.WriteTo(&w); err != nil { + return 0 + } + + return 1 +} diff -Nru golang-github-valyala-fasthttp-20160617/fuzzit/url/url_fuzz.go golang-github-valyala-fasthttp-1.31.0/fuzzit/url/url_fuzz.go --- golang-github-valyala-fasthttp-20160617/fuzzit/url/url_fuzz.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/fuzzit/url/url_fuzz.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,24 @@ +//go:build gofuzz +// +build gofuzz + +package fuzz + +import ( + "bytes" + + "github.com/valyala/fasthttp" +) + +func Fuzz(data []byte) int { + u := fasthttp.AcquireURI() + defer fasthttp.ReleaseURI(u) + + u.UpdateBytes(data) + + w := bytes.Buffer{} + if _, err := u.WriteTo(&w); err != nil { + return 0 + } + + return 1 +} diff -Nru golang-github-valyala-fasthttp-20160617/.github/workflows/lint.yml golang-github-valyala-fasthttp-1.31.0/.github/workflows/lint.yml --- golang-github-valyala-fasthttp-20160617/.github/workflows/lint.yml 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/.github/workflows/lint.yml 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,19 @@ +name: Lint +on: + push: + 
branches: + - master + pull_request: +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-go@v2 + with: + go-version: 1.17.x + - run: go version + - run: diff -u <(echo -n) <(gofmt -d .) + - uses: golangci/golangci-lint-action@v2 + with: + version: v1.28.3 diff -Nru golang-github-valyala-fasthttp-20160617/.github/workflows/security.yml golang-github-valyala-fasthttp-1.31.0/.github/workflows/security.yml --- golang-github-valyala-fasthttp-20160617/.github/workflows/security.yml 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/.github/workflows/security.yml 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,22 @@ +name: Security +on: + push: + branches: + - master + pull_request: +jobs: + test: + strategy: + matrix: + go-version: [1.17.x] + platform: [ubuntu-latest] + runs-on: ${{ matrix.platform }} + steps: + - name: Install Go + uses: actions/setup-go@v1 + with: + go-version: ${{ matrix.go-version }} + - name: Checkout code + uses: actions/checkout@v2 + - name: Security + run: go get github.com/securego/gosec/cmd/gosec; `go env GOPATH`/bin/gosec -exclude=G104,G304 ./... diff -Nru golang-github-valyala-fasthttp-20160617/.github/workflows/test.yml golang-github-valyala-fasthttp-1.31.0/.github/workflows/test.yml --- golang-github-valyala-fasthttp-20160617/.github/workflows/test.yml 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/.github/workflows/test.yml 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,21 @@ +name: Test +on: + push: + branches: + - master + pull_request: +jobs: + test: + strategy: + matrix: + go-version: [1.15.x, 1.16.x, 1.17.x] + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go-version }} + - run: go version + - run: go test ./... + - run: go test -race ./... 
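The four gofuzz targets added above (cookie, request, response, url) all follow the same acquire/parse/serialize/release round-trip. As a point of comparison only — this is not part of the diff, and it assumes Go 1.18+, which is newer than the toolchains in the test matrix above — the url target could be expressed as a native fuzz test roughly like this:

//go:build go1.18

package fuzz

import (
	"bytes"
	"testing"

	"github.com/valyala/fasthttp"
)

// FuzzURI mirrors the gofuzz url target: parse arbitrary bytes into a URI,
// then re-serialize it and make sure the round-trip does not panic.
func FuzzURI(f *testing.F) {
	f.Add([]byte("http://foobar.com/aaa/bbb?cc=dd#ee")) // seed corpus entry

	f.Fuzz(func(t *testing.T, data []byte) {
		u := fasthttp.AcquireURI()
		defer fasthttp.ReleaseURI(u)

		u.UpdateBytes(data)

		var w bytes.Buffer
		if _, err := u.WriteTo(&w); err != nil {
			t.Skipf("serialization failed: %v", err)
		}
	})
}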
diff -Nru golang-github-valyala-fasthttp-20160617/.gitignore golang-github-valyala-fasthttp-1.31.0/.gitignore --- golang-github-valyala-fasthttp-20160617/.gitignore 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/.gitignore 2021-10-09 18:39:05.000000000 +0000 @@ -1,3 +1,7 @@ tags *.pprof *.fasthttp.gz +*.fasthttp.br +.idea +.DS_Store +vendor/ diff -Nru golang-github-valyala-fasthttp-20160617/go.mod golang-github-valyala-fasthttp-1.31.0/go.mod --- golang-github-valyala-fasthttp-20160617/go.mod 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/go.mod 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,13 @@ +module github.com/valyala/fasthttp + +go 1.12 + +require ( + github.com/andybalholm/brotli v1.0.2 + github.com/klauspost/compress v1.13.4 + github.com/valyala/bytebufferpool v1.0.0 + github.com/valyala/tcplisten v1.0.0 + golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a + golang.org/x/net v0.0.0-20210510120150-4163338589ed + golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 +) diff -Nru golang-github-valyala-fasthttp-20160617/go.sum golang-github-valyala-fasthttp-1.31.0/go.sum --- golang-github-valyala-fasthttp-20160617/go.sum 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/go.sum 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,23 @@ +github.com/andybalholm/brotli v1.0.2 h1:JKnhI/XQ75uFBTiuzXpzFrUriDPiZjlOSzh6wXogP0E= +github.com/andybalholm/brotli v1.0.2/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s= +github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a h1:kr2P4QFmQr29mSLA43kwrOcgcReGTfbE9N577tCTuBc= +golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210510120150-4163338589ed h1:p9UgmWI9wKpfYmgaV/IZKGdXc5qEK45tDwwwDyjS26I= +golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210514084401-e8d321eab015 h1:hZR0X1kPW+nwyJ9xRxqZk1vx5RUObAPBdKVvXPDUH/E= +golang.org/x/sys v0.0.0-20210514084401-e8d321eab015/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff -Nru golang-github-valyala-fasthttp-20160617/header.go golang-github-valyala-fasthttp-1.31.0/header.go --- golang-github-valyala-fasthttp-20160617/header.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/header.go 2021-10-09 18:39:05.000000000 +0000 @@ -6,10 +6,16 @@ "errors" "fmt" "io" + "sync" "sync/atomic" "time" ) +const ( + rChar = byte('\r') + nChar = byte('\n') +) + // ResponseHeader represents HTTP response header. // // It is forbidden copying ResponseHeader instances. @@ -18,15 +24,18 @@ // ResponseHeader instance MUST NOT be used from concurrently running // goroutines. type ResponseHeader struct { - noCopy noCopy - - disableNormalizing bool - noHTTP11 bool - connectionClose bool + noCopy noCopy //nolint:unused,structcheck - statusCode int - contentLength int - contentLengthBytes []byte + disableNormalizing bool + noHTTP11 bool + connectionClose bool + noDefaultContentType bool + noDefaultDate bool + + statusCode int + contentLength int + contentLengthBytes []byte + secureErrorLogMessage bool contentType []byte server []byte @@ -45,23 +54,23 @@ // RequestHeader instance MUST NOT be used from concurrently running // goroutines. type RequestHeader struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck disableNormalizing bool noHTTP11 bool connectionClose bool - isGet bool // These two fields have been moved close to other bool fields // for reducing RequestHeader object size. cookiesCollected bool - rawHeadersParsed bool - contentLength int - contentLengthBytes []byte + contentLength int + contentLengthBytes []byte + secureErrorLogMessage bool method []byte requestURI []byte + proto []byte host []byte contentType []byte userAgent []byte @@ -71,6 +80,8 @@ cookies []argsKV + // stores an immutable copy of headers as they were received from the + // wire. rawHeaders []byte } @@ -95,8 +106,6 @@ // * If startPos is negative, then 'bytes=-startPos' value is set. // * If endPos is negative, then 'bytes=startPos-' value is set. func (h *RequestHeader) SetByteRange(startPos, endPos int) { - h.parseRawHeaders() - b := h.bufKV.value[:0] b = append(b, strBytes...) b = append(b, '=') @@ -153,25 +162,16 @@ // ConnectionClose returns true if 'Connection: close' header is set. func (h *RequestHeader) ConnectionClose() bool { - h.parseRawHeaders() - return h.connectionClose -} - -func (h *RequestHeader) connectionCloseFast() bool { - // h.parseRawHeaders() isn't called for performance reasons. - // Use ConnectionClose for triggering raw headers parsing. return h.connectionClose } // SetConnectionClose sets 'Connection: close' header. func (h *RequestHeader) SetConnectionClose() { - // h.parseRawHeaders() isn't called for performance reasons. h.connectionClose = true } // ResetConnectionClose clears 'Connection: close' header if it exists. func (h *RequestHeader) ResetConnectionClose() { - h.parseRawHeaders() if h.connectionClose { h.connectionClose = false h.h = delAllArgsBytes(h.h, strConnection) @@ -180,13 +180,17 @@ // ConnectionUpgrade returns true if 'Connection: Upgrade' header is set. func (h *ResponseHeader) ConnectionUpgrade() bool { - return hasHeaderValue(h.Peek("Connection"), strUpgrade) + return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade) } // ConnectionUpgrade returns true if 'Connection: Upgrade' header is set. 
func (h *RequestHeader) ConnectionUpgrade() bool { - h.parseRawHeaders() - return hasHeaderValue(h.Peek("Connection"), strUpgrade) + return hasHeaderValue(h.Peek(HeaderConnection), strUpgrade) +} + +// PeekCookie is able to returns cookie by a given key from response. +func (h *ResponseHeader) PeekCookie(key string) []byte { + return peekArgStr(h.cookies, key) } // ContentLength returns Content-Length header value. @@ -218,7 +222,7 @@ h.SetConnectionClose() value = strIdentity } - h.h = setArgBytes(h.h, strTransferEncoding, value) + h.h = setArgBytes(h.h, strTransferEncoding, value, argsHasValue) } } @@ -241,10 +245,12 @@ // It may be negative: // -1 means Transfer-Encoding: chunked. func (h *RequestHeader) ContentLength() int { - if h.noBody() { - return 0 - } - h.parseRawHeaders() + return h.realContentLength() +} + +// realContentLength returns the actual Content-Length set in the request, +// including positive lengths for GET/HEAD requests. +func (h *RequestHeader) realContentLength() int { return h.contentLength } @@ -252,21 +258,30 @@ // // Negative content-length sets 'Transfer-Encoding: chunked' header. func (h *RequestHeader) SetContentLength(contentLength int) { - h.parseRawHeaders() h.contentLength = contentLength if contentLength >= 0 { h.contentLengthBytes = AppendUint(h.contentLengthBytes[:0], contentLength) h.h = delAllArgsBytes(h.h, strTransferEncoding) } else { h.contentLengthBytes = h.contentLengthBytes[:0] - h.h = setArgBytes(h.h, strTransferEncoding, strChunked) + h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) } } +func (h *ResponseHeader) isCompressibleContentType() bool { + contentType := h.ContentType() + return bytes.HasPrefix(contentType, strTextSlash) || + bytes.HasPrefix(contentType, strApplicationSlash) || + bytes.HasPrefix(contentType, strImageSVG) || + bytes.HasPrefix(contentType, strImageIcon) || + bytes.HasPrefix(contentType, strFontSlash) || + bytes.HasPrefix(contentType, strMultipartSlash) +} + // ContentType returns Content-Type header value. func (h *ResponseHeader) ContentType() []byte { contentType := h.contentType - if len(h.contentType) == 0 { + if !h.noDefaultContentType && len(h.contentType) == 0 { contentType = defaultContentType } return contentType @@ -299,19 +314,16 @@ // ContentType returns Content-Type header value. func (h *RequestHeader) ContentType() []byte { - h.parseRawHeaders() return h.contentType } // SetContentType sets Content-Type header value. func (h *RequestHeader) SetContentType(contentType string) { - h.parseRawHeaders() h.contentType = append(h.contentType[:0], contentType...) } // SetContentTypeBytes sets Content-Type header value. func (h *RequestHeader) SetContentTypeBytes(contentType []byte) { - h.parseRawHeaders() h.contentType = append(h.contentType[:0], contentType...) } @@ -319,8 +331,6 @@ // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. func (h *RequestHeader) SetMultipartFormBoundary(boundary string) { - h.parseRawHeaders() - b := h.bufKV.value[:0] b = append(b, strMultipartFormData...) b = append(b, ';', ' ') @@ -336,8 +346,6 @@ // 'multipart/form-data; boundary=...' // where ... is substituted by the given boundary. func (h *RequestHeader) SetMultipartFormBoundaryBytes(boundary []byte) { - h.parseRawHeaders() - b := h.bufKV.value[:0] b = append(b, strMultipartFormData...) 
b = append(b, ';', ' ') @@ -383,6 +391,9 @@ if n = bytes.IndexByte(b, ';'); n >= 0 { b = b[:n] } + if len(b) > 1 && b[0] == '"' && b[len(b)-1] == '"' { + b = b[1 : len(b)-1] + } return b } return nil @@ -390,50 +401,31 @@ // Host returns Host header value. func (h *RequestHeader) Host() []byte { - if len(h.host) > 0 { - return h.host - } - if !h.rawHeadersParsed { - // fast path without employing full headers parsing. - host := peekRawHeader(h.rawHeaders, strHost) - if len(host) > 0 { - h.host = append(h.host[:0], host...) - return h.host - } - } - - // slow path. - h.parseRawHeaders() return h.host } // SetHost sets Host header value. func (h *RequestHeader) SetHost(host string) { - h.parseRawHeaders() h.host = append(h.host[:0], host...) } // SetHostBytes sets Host header value. func (h *RequestHeader) SetHostBytes(host []byte) { - h.parseRawHeaders() h.host = append(h.host[:0], host...) } // UserAgent returns User-Agent header value. func (h *RequestHeader) UserAgent() []byte { - h.parseRawHeaders() return h.userAgent } // SetUserAgent sets User-Agent header value. func (h *RequestHeader) SetUserAgent(userAgent string) { - h.parseRawHeaders() h.userAgent = append(h.userAgent[:0], userAgent...) } // SetUserAgentBytes sets User-Agent header value. func (h *RequestHeader) SetUserAgentBytes(userAgent []byte) { - h.parseRawHeaders() h.userAgent = append(h.userAgent[:0], userAgent...) } @@ -455,14 +447,14 @@ // Method returns HTTP request method. func (h *RequestHeader) Method() []byte { if len(h.method) == 0 { - return strGet + return []byte(MethodGet) } return h.method } // SetMethod sets HTTP request method. func (h *RequestHeader) SetMethod(method string) { - h.method = append(h.method, method...) + h.method = append(h.method[:0], method...) } // SetMethodBytes sets HTTP request method. @@ -470,6 +462,26 @@ h.method = append(h.method[:0], method...) } +// Protocol returns HTTP protocol. +func (h *RequestHeader) Protocol() []byte { + if len(h.proto) == 0 { + return strHTTP11 + } + return h.proto +} + +// SetProtocol sets HTTP request protocol. +func (h *RequestHeader) SetProtocol(method string) { + h.proto = append(h.proto[:0], method...) + h.noHTTP11 = !bytes.Equal(h.proto, strHTTP11) +} + +// SetProtocolBytes sets HTTP request protocol. +func (h *RequestHeader) SetProtocolBytes(method []byte) { + h.proto = append(h.proto[:0], method...) + h.noHTTP11 = !bytes.Equal(h.proto, strHTTP11) +} + // RequestURI returns RequestURI from the first HTTP request line. func (h *RequestHeader) RequestURI() []byte { requestURI := h.requestURI @@ -495,35 +507,47 @@ // IsGet returns true if request method is GET. func (h *RequestHeader) IsGet() bool { - // Optimize fast path for GET requests. - if !h.isGet { - h.isGet = bytes.Equal(h.Method(), strGet) - } - return h.isGet + return string(h.Method()) == MethodGet } -// IsPost returns true if request methos is POST. +// IsPost returns true if request method is POST. func (h *RequestHeader) IsPost() bool { - return bytes.Equal(h.Method(), strPost) + return string(h.Method()) == MethodPost } // IsPut returns true if request method is PUT. func (h *RequestHeader) IsPut() bool { - return bytes.Equal(h.Method(), strPut) + return string(h.Method()) == MethodPut } // IsHead returns true if request method is HEAD. func (h *RequestHeader) IsHead() bool { - // Fast path - if h.isGet { - return false - } - return bytes.Equal(h.Method(), strHead) + return string(h.Method()) == MethodHead } // IsDelete returns true if request method is DELETE. 
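The EnableNormalizing/DisableNormalizing pair and SetNoDefaultContentType introduced here can be exercised as follows (header names and values are arbitrary):

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	var h fasthttp.RequestHeader

	// With normalization on (the default), names are canonicalized:
	// "x-request-id" is stored and written out as "X-Request-Id".
	h.Set("x-request-id", "42")
	fmt.Println(string(h.Peek("X-Request-Id"))) // 42

	// DisableNormalizing keeps names byte-for-byte as given;
	// EnableNormalizing restores the default for the same object.
	h.Reset()
	h.DisableNormalizing()
	h.Set("x-request-id", "42")
	fmt.Println(string(h.Peek("x-request-id"))) // 42
	h.EnableNormalizing()

	// On the response side, SetNoDefaultContentType(true) suppresses the
	// default Content-Type for responses that never set one explicitly.
	var rh fasthttp.ResponseHeader
	rh.SetNoDefaultContentType(true)
	fmt.Println(len(rh.ContentType())) // 0
}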
func (h *RequestHeader) IsDelete() bool { - return bytes.Equal(h.Method(), strDelete) + return string(h.Method()) == MethodDelete +} + +// IsConnect returns true if request method is CONNECT. +func (h *RequestHeader) IsConnect() bool { + return string(h.Method()) == MethodConnect +} + +// IsOptions returns true if request method is OPTIONS. +func (h *RequestHeader) IsOptions() bool { + return string(h.Method()) == MethodOptions +} + +// IsTrace returns true if request method is TRACE. +func (h *RequestHeader) IsTrace() bool { + return string(h.Method()) == MethodTrace +} + +// IsPatch returns true if request method is PATCH. +func (h *RequestHeader) IsPatch() bool { + return string(h.Method()) == MethodPatch } // IsHTTP11 returns true if the request is HTTP/1.1. @@ -593,6 +617,22 @@ h.disableNormalizing = true } +// EnableNormalizing enables header names' normalization. +// +// Header names are normalized by uppercasing the first letter and +// all the first letters following dashes, while lowercasing all +// the other letters. +// Examples: +// +// * CONNECTION -> Connection +// * conteNT-tYPE -> Content-Type +// * foo-bar-baz -> Foo-Bar-Baz +// +// This is enabled by default unless disabled using DisableNormalizing() +func (h *RequestHeader) EnableNormalizing() { + h.disableNormalizing = false +} + // DisableNormalizing disables header names' normalization. // // By default all the header names are normalized by uppercasing @@ -609,9 +649,32 @@ h.disableNormalizing = true } +// EnableNormalizing enables header names' normalization. +// +// Header names are normalized by uppercasing the first letter and +// all the first letters following dashes, while lowercasing all +// the other letters. +// Examples: +// +// * CONNECTION -> Connection +// * conteNT-tYPE -> Content-Type +// * foo-bar-baz -> Foo-Bar-Baz +// +// This is enabled by default unless disabled using DisableNormalizing() +func (h *ResponseHeader) EnableNormalizing() { + h.disableNormalizing = false +} + +// SetNoDefaultContentType allows you to control if a default Content-Type header will be set (false) or not (true). +func (h *ResponseHeader) SetNoDefaultContentType(noDefaultContentType bool) { + h.noDefaultContentType = noDefaultContentType +} + // Reset clears response header. func (h *ResponseHeader) Reset() { h.disableNormalizing = false + h.SetNoDefaultContentType(false) + h.noDefaultDate = false h.resetSkipNormalize() } @@ -639,12 +702,12 @@ func (h *RequestHeader) resetSkipNormalize() { h.noHTTP11 = false h.connectionClose = false - h.isGet = false h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] h.method = h.method[:0] + h.proto = h.proto[:0] h.requestURI = h.requestURI[:0] h.host = h.host[:0] h.contentType = h.contentType[:0] @@ -655,7 +718,6 @@ h.cookiesCollected = false h.rawHeaders = h.rawHeaders[:0] - h.rawHeadersParsed = false } // CopyTo copies all the headers to dst. @@ -665,6 +727,8 @@ dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 dst.connectionClose = h.connectionClose + dst.noDefaultContentType = h.noDefaultContentType + dst.noDefaultDate = h.noDefaultDate dst.statusCode = h.statusCode dst.contentLength = h.contentLength @@ -682,11 +746,11 @@ dst.disableNormalizing = h.disableNormalizing dst.noHTTP11 = h.noHTTP11 dst.connectionClose = h.connectionClose - dst.isGet = h.isGet dst.contentLength = h.contentLength dst.contentLengthBytes = append(dst.contentLengthBytes[:0], h.contentLengthBytes...) dst.method = append(dst.method[:0], h.method...) 
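A small sketch of VisitAllInOrder against a hand-written request, showing that keys come back in wire order instead of the grouped order used by VisitAll (the host and header names are made up):

package main

import (
	"bufio"
	"fmt"
	"strings"

	"github.com/valyala/fasthttp"
)

func main() {
	raw := "GET /foo HTTP/1.1\r\n" +
		"Host: example.com\r\n" +
		"X-B: second\r\n" +
		"X-A: first\r\n" +
		"\r\n"

	var h fasthttp.RequestHeader
	if err := h.Read(bufio.NewReader(strings.NewReader(raw))); err != nil {
		panic(err)
	}

	// VisitAllInOrder re-scans the raw header bytes, so the callback sees
	// Host, X-B, X-A exactly as they appeared on the wire.
	h.VisitAllInOrder(func(key, value []byte) {
		fmt.Printf("%s: %s\n", key, value)
	})
}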
+ dst.proto = append(dst.proto[:0], h.proto...) dst.requestURI = append(dst.requestURI[:0], h.requestURI...) dst.host = append(dst.host[:0], h.host...) dst.contentType = append(dst.contentType[:0], h.contentType...) @@ -695,7 +759,6 @@ dst.cookies = copyArgs(dst.cookies, h.cookies) dst.cookiesCollected = h.cookiesCollected dst.rawHeaders = append(dst.rawHeaders[:0], h.rawHeaders...) - dst.rawHeadersParsed = h.rawHeadersParsed } // VisitAll calls f for each header. @@ -740,7 +803,6 @@ // // f must not retain references to key and/or value after returning. func (h *RequestHeader) VisitAllCookie(f func(key, value []byte)) { - h.parseRawHeaders() h.collectCookies() visitArgs(h.cookies, f) } @@ -749,8 +811,9 @@ // // f must not retain references to key and/or value after returning. // Copy key and/or value contents before returning if you need retaining them. +// +// To get the headers in order they were received use VisitAllInOrder. func (h *RequestHeader) VisitAll(f func(key, value []byte)) { - h.parseRawHeaders() host := h.Host() if len(host) > 0 { f(strHost, host) @@ -778,6 +841,24 @@ } } +// VisitAllInOrder calls f for each header in the order they were received. +// +// f must not retain references to key and/or value after returning. +// Copy key and/or value contents before returning if you need retaining them. +// +// This function is slightly slower than VisitAll because it has to reparse the +// raw headers to get the order. +func (h *RequestHeader) VisitAllInOrder(f func(key, value []byte)) { + var s headerScanner + s.b = h.rawHeaders + s.disableNormalizing = h.disableNormalizing + for s.next() { + if len(s.key) > 0 { + f(s.key, s.value) + } + } +} + // Del deletes header with the given key. func (h *ResponseHeader) Del(key string) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) @@ -793,16 +874,16 @@ func (h *ResponseHeader) del(key []byte) { switch string(key) { - case "Content-Type": + case HeaderContentType: h.contentType = h.contentType[:0] - case "Server": + case HeaderServer: h.server = h.server[:0] - case "Set-Cookie": + case HeaderSetCookie: h.cookies = h.cookies[:0] - case "Content-Length": + case HeaderContentLength: h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] - case "Connection": + case HeaderConnection: h.connectionClose = false } h.h = delAllArgsBytes(h.h, key) @@ -810,14 +891,12 @@ // Del deletes header with the given key. func (h *RequestHeader) Del(key string) { - h.parseRawHeaders() k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.del(k) } // DelBytes deletes header with the given key. func (h *RequestHeader) DelBytes(key []byte) { - h.parseRawHeaders() h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) h.del(h.bufKV.key) @@ -825,71 +904,205 @@ func (h *RequestHeader) del(key []byte) { switch string(key) { - case "Host": + case HeaderHost: h.host = h.host[:0] - case "Content-Type": + case HeaderContentType: h.contentType = h.contentType[:0] - case "User-Agent": + case HeaderUserAgent: h.userAgent = h.userAgent[:0] - case "Cookie": + case HeaderCookie: h.cookies = h.cookies[:0] - case "Content-Length": + case HeaderContentLength: h.contentLength = 0 h.contentLengthBytes = h.contentLengthBytes[:0] - case "Connection": + case HeaderConnection: h.connectionClose = false } h.h = delAllArgsBytes(h.h, key) } +// setSpecialHeader handles special headers and return true when a header is processed. 
+func (h *ResponseHeader) setSpecialHeader(key, value []byte) bool { + if len(key) == 0 { + return false + } + + switch key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(strContentType, key) { + h.SetContentTypeBytes(value) + return true + } else if caseInsensitiveCompare(strContentLength, key) { + if contentLength, err := parseContentLength(value); err == nil { + h.contentLength = contentLength + h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) + } + return true + } else if caseInsensitiveCompare(strConnection, key) { + if bytes.Equal(strClose, value) { + h.SetConnectionClose() + } else { + h.ResetConnectionClose() + h.h = setArgBytes(h.h, key, value, argsHasValue) + } + return true + } + case 's': + if caseInsensitiveCompare(strServer, key) { + h.SetServerBytes(value) + return true + } else if caseInsensitiveCompare(strSetCookie, key) { + var kv *argsKV + h.cookies, kv = allocArg(h.cookies) + kv.key = getCookieKey(kv.key, value) + kv.value = append(kv.value[:0], value...) + return true + } + case 't': + if caseInsensitiveCompare(strTransferEncoding, key) { + // Transfer-Encoding is managed automatically. + return true + } + case 'd': + if caseInsensitiveCompare(strDate, key) { + // Date is managed automatically. + return true + } + } + + return false +} + +// setSpecialHeader handles special headers and return true when a header is processed. +func (h *RequestHeader) setSpecialHeader(key, value []byte) bool { + if len(key) == 0 { + return false + } + + switch key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(strContentType, key) { + h.SetContentTypeBytes(value) + return true + } else if caseInsensitiveCompare(strContentLength, key) { + if contentLength, err := parseContentLength(value); err == nil { + h.contentLength = contentLength + h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) + } + return true + } else if caseInsensitiveCompare(strConnection, key) { + if bytes.Equal(strClose, value) { + h.SetConnectionClose() + } else { + h.ResetConnectionClose() + h.h = setArgBytes(h.h, key, value, argsHasValue) + } + return true + } else if caseInsensitiveCompare(strCookie, key) { + h.collectCookies() + h.cookies = parseRequestCookies(h.cookies, value) + return true + } + case 't': + if caseInsensitiveCompare(strTransferEncoding, key) { + // Transfer-Encoding is managed automatically. + return true + } + case 'h': + if caseInsensitiveCompare(strHost, key) { + h.SetHostBytes(value) + return true + } + case 'u': + if caseInsensitiveCompare(strUserAgent, key) { + h.SetUserAgentBytes(value) + return true + } + } + + return false +} + // Add adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use Set for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) Add(key, value string) { - k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) - h.h = appendArg(h.h, b2s(k), value) + h.AddBytesKV(s2b(key), s2b(value)) } // AddBytesK adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use SetBytesK for setting a single header for the given key. 
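The setSpecialHeader dispatch above is what gives Add its documented caveat: for the listed special headers Add degenerates to Set, while ordinary keys may still appear multiple times. A sketch (the header names are arbitrary):

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	var h fasthttp.ResponseHeader

	// Ordinary keys accumulate.
	h.Add("X-Trace", "a")
	h.Add("X-Trace", "b")

	// Content-Type is routed through setSpecialHeader, so the second Add
	// overwrites the first.
	h.Add(fasthttp.HeaderContentType, "text/plain")
	h.Add(fasthttp.HeaderContentType, "application/json")

	fmt.Println(string(h.ContentType())) // application/json
	fmt.Print(h.String())                // two X-Trace lines, a single Content-Type line
}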
+// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesK(key []byte, value string) { - h.Add(b2s(key), value) + h.AddBytesKV(key, s2b(value)) } // AddBytesV adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use SetBytesV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesV(key string, value []byte) { - h.Add(key, b2s(value)) + h.AddBytesKV(s2b(key), value) } // AddBytesKV adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use SetBytesKV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Server, Set-Cookie, +// Transfer-Encoding and Date headers can only be set once and will +// overwrite the previous value. func (h *ResponseHeader) AddBytesKV(key, value []byte) { - h.Add(b2s(key), b2s(value)) + if h.setSpecialHeader(key, value) { + return + } + + k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing) + h.h = appendArgBytes(h.h, k, value, argsHasValue) } // Set sets the given 'key: value' header. +// +// Use Add for setting multiple header values under the same key. func (h *ResponseHeader) Set(key, value string) { initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } // SetBytesK sets the given 'key: value' header. +// +// Use AddBytesK for setting multiple header values under the same key. func (h *ResponseHeader) SetBytesK(key []byte, value string) { h.bufKV.value = append(h.bufKV.value[:0], value...) h.SetBytesKV(key, h.bufKV.value) } // SetBytesV sets the given 'key: value' header. +// +// Use AddBytesV for setting multiple header values under the same key. func (h *ResponseHeader) SetBytesV(key string, value []byte) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. +// +// Use AddBytesKV for setting multiple header values under the same key. func (h *ResponseHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) @@ -899,47 +1112,24 @@ // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *ResponseHeader) SetCanonical(key, value []byte) { - switch string(key) { - case "Content-Type": - h.SetContentTypeBytes(value) - case "Server": - h.SetServerBytes(value) - case "Set-Cookie": - var kv *argsKV - h.cookies, kv = allocArg(h.cookies) - kv.key = getCookieKey(kv.key, value) - kv.value = append(kv.value[:0], value...) - case "Content-Length": - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - } - case "Connection": - if bytes.Equal(strClose, value) { - h.SetConnectionClose() - } else { - h.ResetConnectionClose() - h.h = setArgBytes(h.h, key, value) - } - case "Transfer-Encoding": - // Transfer-Encoding is managed automatically. 
- case "Date": - // Date is managed automatically. - default: - h.h = setArgBytes(h.h, key, value) + if h.setSpecialHeader(key, value) { + return } + + h.h = setArgBytes(h.h, key, value, argsHasValue) } // SetCookie sets the given response cookie. +// +// It is save re-using the cookie after the function returns. func (h *ResponseHeader) SetCookie(cookie *Cookie) { - h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie()) + h.cookies = setArgBytes(h.cookies, cookie.Key(), cookie.Cookie(), argsHasValue) } // SetCookie sets 'key: value' cookies. func (h *RequestHeader) SetCookie(key, value string) { - h.parseRawHeaders() h.collectCookies() - h.cookies = setArg(h.cookies, key, value) + h.cookies = setArg(h.cookies, key, value, argsHasValue) } // SetCookieBytesK sets 'key: value' cookies. @@ -953,6 +1143,16 @@ } // DelClientCookie instructs the client to remove the given cookie. +// This doesn't work for a cookie with specific domain or path, +// you should delete it manually like: +// +// c := AcquireCookie() +// c.SetKey(key) +// c.SetDomain("example.com") +// c.SetPath("/path") +// c.SetExpire(CookieExpireDelete) +// h.SetCookie(c) +// ReleaseCookie(c) // // Use DelCookie if you want just removing the cookie from response header. func (h *ResponseHeader) DelClientCookie(key string) { @@ -966,6 +1166,16 @@ } // DelClientCookieBytes instructs the client to remove the given cookie. +// This doesn't work for a cookie with specific domain or path, +// you should delete it manually like: +// +// c := AcquireCookie() +// c.SetKey(key) +// c.SetDomain("example.com") +// c.SetPath("/path") +// c.SetExpire(CookieExpireDelete) +// h.SetCookie(c) +// ReleaseCookie(c) // // Use DelCookieBytes if you want just removing the cookie from response header. func (h *ResponseHeader) DelClientCookieBytes(key []byte) { @@ -990,7 +1200,6 @@ // DelCookie removes cookie under the given key. func (h *RequestHeader) DelCookie(key string) { - h.parseRawHeaders() h.collectCookies() h.cookies = delAllArgs(h.cookies, key) } @@ -1007,59 +1216,78 @@ // DelAllCookies removes all the cookies from request headers. func (h *RequestHeader) DelAllCookies() { - h.parseRawHeaders() h.collectCookies() h.cookies = h.cookies[:0] } // Add adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use Set for setting a single header for the given key. func (h *RequestHeader) Add(key, value string) { - k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) - h.h = appendArg(h.h, b2s(k), value) + h.AddBytesKV(s2b(key), s2b(value)) } // AddBytesK adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use SetBytesK for setting a single header for the given key. func (h *RequestHeader) AddBytesK(key []byte, value string) { - h.Add(b2s(key), value) + h.AddBytesKV(key, s2b(value)) } // AddBytesV adds the given 'key: value' header. // -// Multiple headers with the same key may be added. +// Multiple headers with the same key may be added with this function. +// Use SetBytesV for setting a single header for the given key. func (h *RequestHeader) AddBytesV(key string, value []byte) { - h.Add(key, b2s(value)) + h.AddBytesKV(s2b(key), value) } // AddBytesKV adds the given 'key: value' header. // -// Multiple headers with the same key may be added. 
+// Multiple headers with the same key may be added with this function. +// Use SetBytesKV for setting a single header for the given key. +// +// the Content-Type, Content-Length, Connection, Cookie, +// Transfer-Encoding, Host and User-Agent headers can only be set once +// and will overwrite the previous value. func (h *RequestHeader) AddBytesKV(key, value []byte) { - h.Add(b2s(key), b2s(value)) + if h.setSpecialHeader(key, value) { + return + } + + k := getHeaderKeyBytes(&h.bufKV, b2s(key), h.disableNormalizing) + h.h = appendArgBytes(h.h, k, value, argsHasValue) } // Set sets the given 'key: value' header. +// +// Use Add for setting multiple header values under the same key. func (h *RequestHeader) Set(key, value string) { initHeaderKV(&h.bufKV, key, value, h.disableNormalizing) h.SetCanonical(h.bufKV.key, h.bufKV.value) } // SetBytesK sets the given 'key: value' header. +// +// Use AddBytesK for setting multiple header values under the same key. func (h *RequestHeader) SetBytesK(key []byte, value string) { h.bufKV.value = append(h.bufKV.value[:0], value...) h.SetBytesKV(key, h.bufKV.value) } // SetBytesV sets the given 'key: value' header. +// +// Use AddBytesV for setting multiple header values under the same key. func (h *RequestHeader) SetBytesV(key string, value []byte) { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) h.SetCanonical(k, value) } // SetBytesKV sets the given 'key: value' header. +// +// Use AddBytesKV for setting multiple header values under the same key. func (h *RequestHeader) SetBytesKV(key, value []byte) { h.bufKV.key = append(h.bufKV.key[:0], key...) normalizeHeaderKey(h.bufKV.key, h.disableNormalizing) @@ -1069,40 +1297,18 @@ // SetCanonical sets the given 'key: value' header assuming that // key is in canonical form. func (h *RequestHeader) SetCanonical(key, value []byte) { - h.parseRawHeaders() - switch string(key) { - case "Host": - h.SetHostBytes(value) - case "Content-Type": - h.SetContentTypeBytes(value) - case "User-Agent": - h.SetUserAgentBytes(value) - case "Cookie": - h.collectCookies() - h.cookies = parseRequestCookies(h.cookies, value) - case "Content-Length": - if contentLength, err := parseContentLength(value); err == nil { - h.contentLength = contentLength - h.contentLengthBytes = append(h.contentLengthBytes[:0], value...) - } - case "Connection": - if bytes.Equal(strClose, value) { - h.SetConnectionClose() - } else { - h.ResetConnectionClose() - h.h = setArgBytes(h.h, key, value) - } - case "Transfer-Encoding": - // Transfer-Encoding is managed automatically. - default: - h.h = setArgBytes(h.h, key, value) + if h.setSpecialHeader(key, value) { + return } + + h.h = setArgBytes(h.h, key, value, argsHasValue) } // Peek returns header value for the given key. // -// Returned value is valid until the next call to ResponseHeader. -// Do not store references to returned value. Make copies instead. +// The returned value is valid until the response is released, +// either though ReleaseResponse or your request handler returning. +// Do not store references to the returned value. Make copies instead. func (h *ResponseHeader) Peek(key string) []byte { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) return h.peek(k) @@ -1110,7 +1316,8 @@ // PeekBytes returns header value for the given key. // -// Returned value is valid until the next call to ResponseHeader. +// The returned value is valid until the response is released, +// either though ReleaseResponse or your request handler returning. 
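The lifetime rule spelled out in the Peek comments above is easiest to see inside a handler; this sketch (the header name and listen address are illustrative) copies the peeked slice before letting it escape:

package main

import (
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	handler := func(ctx *fasthttp.RequestCtx) {
		// Peek returns a slice into fasthttp's internal buffers; it is only
		// valid until the handler returns and the request is recycled.
		id := ctx.Request.Header.Peek("X-Request-Id")

		// Copy before handing the value to anything that outlives the handler.
		saved := append([]byte(nil), id...)
		go log.Printf("request id: %s", saved)

		ctx.SetStatusCode(fasthttp.StatusOK)
	}

	log.Fatal(fasthttp.ListenAndServe("127.0.0.1:8080", handler))
}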
// Do not store references to returned value. Make copies instead. func (h *ResponseHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) @@ -1120,7 +1327,8 @@ // Peek returns header value for the given key. // -// Returned value is valid until the next call to RequestHeader. +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) Peek(key string) []byte { k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) @@ -1129,7 +1337,8 @@ // PeekBytes returns header value for the given key. // -// Returned value is valid until the next call to RequestHeader. +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. // Do not store references to returned value. Make copies instead. func (h *RequestHeader) PeekBytes(key []byte) []byte { h.bufKV.key = append(h.bufKV.key[:0], key...) @@ -1139,38 +1348,44 @@ func (h *ResponseHeader) peek(key []byte) []byte { switch string(key) { - case "Content-Type": + case HeaderContentType: return h.ContentType() - case "Server": + case HeaderServer: return h.Server() - case "Connection": + case HeaderConnection: if h.ConnectionClose() { return strClose } return peekArgBytes(h.h, key) - case "Content-Length": + case HeaderContentLength: return h.contentLengthBytes + case HeaderSetCookie: + return appendResponseCookieBytes(nil, h.cookies) default: return peekArgBytes(h.h, key) } } func (h *RequestHeader) peek(key []byte) []byte { - h.parseRawHeaders() switch string(key) { - case "Host": + case HeaderHost: return h.Host() - case "Content-Type": + case HeaderContentType: return h.ContentType() - case "User-Agent": + case HeaderUserAgent: return h.UserAgent() - case "Connection": + case HeaderConnection: if h.ConnectionClose() { return strClose } return peekArgBytes(h.h, key) - case "Content-Length": + case HeaderContentLength: return h.contentLengthBytes + case HeaderCookie: + if h.cookiesCollected { + return appendRequestCookieBytes(nil, h.cookies) + } + return peekArgBytes(h.h, key) default: return peekArgBytes(h.h, key) } @@ -1178,14 +1393,12 @@ // Cookie returns cookie for the given key. func (h *RequestHeader) Cookie(key string) []byte { - h.parseRawHeaders() h.collectCookies() return peekArgStr(h.cookies, key) } // CookieBytes returns cookie for the given key. func (h *RequestHeader) CookieBytes(key []byte) []byte { - h.parseRawHeaders() h.collectCookies() return peekArgBytes(h.cookies, key) } @@ -1198,7 +1411,7 @@ if v == nil { return false } - cookie.ParseBytes(v) + cookie.ParseBytes(v) //nolint:errcheck return true } @@ -1224,48 +1437,85 @@ h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { - // treat all errors on the first byte read as EOF + // Return ErrTimeout on any timeout. + if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() { + return ErrTimeout + } + // treat all other errors on the first byte read as EOF if n == 1 || err == io.EOF { return io.EOF } + + // This is for go 1.6 bug. See https://github.com/golang/go/issues/14121 . 
if err == bufio.ErrBufferFull { - err = bufferFullError(r) + if h.secureErrorLogMessage { + return &ErrSmallBuffer{ + error: fmt.Errorf("error when reading response headers"), + } + } + return &ErrSmallBuffer{ + error: fmt.Errorf("error when reading response headers: %s", errSmallBuffer), + } } + return fmt.Errorf("error when reading response headers: %s", err) } - isEOF := (err != nil) b = mustPeekBuffered(r) - var headersLen int - if headersLen, err = h.parse(b); err != nil { - if err == errNeedMore { - if !isEOF { - return err - } - - // Buggy servers may leave trailing CRLFs after response body. - // Treat this case as EOF. - if isOnlyCRLF(b) { - return io.EOF - } - } - bStart, bEnd := bufferStartEnd(b) - return fmt.Errorf("error when reading response headers: %s. buf=%q...%q", err, bStart, bEnd) + headersLen, errParse := h.parse(b) + if errParse != nil { + return headerError("response", err, errParse, b, h.secureErrorLogMessage) } mustDiscard(r, headersLen) return nil } +func headerError(typ string, err, errParse error, b []byte, secureErrorLogMessage bool) error { + if errParse != errNeedMore { + return headerErrorMsg(typ, errParse, b, secureErrorLogMessage) + } + if err == nil { + return errNeedMore + } + + // Buggy servers may leave trailing CRLFs after http body. + // Treat this case as EOF. + if isOnlyCRLF(b) { + return io.EOF + } + + if err != bufio.ErrBufferFull { + return headerErrorMsg(typ, err, b, secureErrorLogMessage) + } + return &ErrSmallBuffer{ + error: headerErrorMsg(typ, errSmallBuffer, b, secureErrorLogMessage), + } +} + +func headerErrorMsg(typ string, err error, b []byte, secureErrorLogMessage bool) error { + if secureErrorLogMessage { + return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d", typ, err, len(b)) + } + return fmt.Errorf("error when reading %s headers: %s. Buffer size=%d, contents: %s", typ, err, len(b), bufferSnippet(b)) +} + // Read reads request header from r. // // io.EOF is returned if r is closed before reading the first header byte. func (h *RequestHeader) Read(r *bufio.Reader) error { + return h.readLoop(r, true) +} + +// readLoop reads request header from r optionally loops until it has enough data. +// +// io.EOF is returned if r is closed before reading the first header byte. +func (h *RequestHeader) readLoop(r *bufio.Reader, waitForMore bool) error { n := 1 for { err := h.tryRead(r, n) if err == nil { return nil } - if err != errNeedMore { + if !waitForMore || err != errNeedMore { h.resetSkipNormalize() return err } @@ -1277,48 +1527,39 @@ h.resetSkipNormalize() b, err := r.Peek(n) if len(b) == 0 { - // treat all errors on the first byte read as EOF - if n == 1 || err == io.EOF { - return io.EOF + if err == io.EOF { + return err } + + if err == nil { + panic("bufio.Reader.Peek() returned nil, nil") + } + + // This is for go 1.6 bug. See https://github.com/golang/go/issues/14121 . if err == bufio.ErrBufferFull { - err = bufferFullError(r) + return &ErrSmallBuffer{ + error: fmt.Errorf("error when reading request headers: %s", errSmallBuffer), + } + } + + // n == 1 on the first read for the request. + if n == 1 { + // We didn't read a single byte. + return ErrNothingRead{err} } + return fmt.Errorf("error when reading request headers: %s", err) } - isEOF := (err != nil) b = mustPeekBuffered(r) - var headersLen int - if headersLen, err = h.parse(b); err != nil { - if err == errNeedMore { - if !isEOF { - return err - } - - // Buggy clients may leave trailing CRLFs after the request body. - // Treat this case as EOF. 
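A sketch that provokes the ErrSmallBuffer path introduced above by parsing a header block through a deliberately tiny bufio.Reader; on a real Server the corresponding knob would be its ReadBufferSize setting, and ErrNothingRead is what a keep-alive connection yields when it closes before sending a byte (the header contents are made up):

package main

import (
	"bufio"
	"errors"
	"fmt"
	"strings"

	"github.com/valyala/fasthttp"
)

func main() {
	raw := "GET / HTTP/1.1\r\n" +
		"Host: example.com\r\n" +
		"X-Big: " + strings.Repeat("x", 100) + "\r\n" +
		"\r\n"

	// 16 bytes is the smallest buffer bufio allows; the header block does
	// not fit, so Read reports ErrSmallBuffer instead of errNeedMore.
	var h fasthttp.RequestHeader
	err := h.Read(bufio.NewReaderSize(strings.NewReader(raw), 16))

	var small *fasthttp.ErrSmallBuffer
	fmt.Println(errors.As(err, &small)) // true
}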
- if isOnlyCRLF(b) { - return io.EOF - } - } - bStart, bEnd := bufferStartEnd(b) - return fmt.Errorf("error when reading request headers: %s. buf=%q...%q", err, bStart, bEnd) + headersLen, errParse := h.parse(b) + if errParse != nil { + return headerError("request", err, errParse, b, h.secureErrorLogMessage) } mustDiscard(r, headersLen) return nil } -func bufferFullError(r *bufio.Reader) error { - n := r.Buffered() - b, err := r.Peek(n) - if err != nil { - panic(fmt.Sprintf("BUG: unexpected error returned from bufio.Reader.Peek(Buffered()): %s", err)) - } - bStart, bEnd := bufferStartEnd(b) - return fmt.Errorf("headers exceed %d bytes. Increase ReadBufferSize. buf=%q...%q", n, bStart, bEnd) -} - -func bufferStartEnd(b []byte) ([]byte, []byte) { +func bufferSnippet(b []byte) string { n := len(b) start := 200 end := n - start @@ -1326,19 +1567,23 @@ start = n end = n } - return b[:start], b[end:] + bStart, bEnd := b[:start], b[end:] + if len(bEnd) == 0 { + return fmt.Sprintf("%q", b) + } + return fmt.Sprintf("%q...%q", bStart, bEnd) } func isOnlyCRLF(b []byte) bool { for _, ch := range b { - if ch != '\r' && ch != '\n' { + if ch != rChar && ch != nChar { return false } } return true } -func init() { +func updateServerDate() { refreshServerDate() go func() { for { @@ -1348,7 +1593,10 @@ }() } -var serverDate atomic.Value +var ( + serverDate atomic.Value + serverDateOnce sync.Once // serverDateOnce.Do(updateServerDate) +) func refreshServerDate() { b := AppendHTTPDate(nil, time.Now()) @@ -1371,7 +1619,9 @@ // Header returns response header representation. // -// The returned value is valid until the next call to ResponseHeader methods. +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Do not store references to returned value. Make copies instead. func (h *ResponseHeader) Header() []byte { h.bufKV.value = h.AppendBytes(h.bufKV.value[:0]) return h.bufKV.value @@ -1392,17 +1642,23 @@ dst = append(dst, statusLine(statusCode)...) server := h.Server() - if len(server) == 0 { - server = defaultServerName + if len(server) != 0 { + dst = appendHeaderLine(dst, strServer, server) + } + + if !h.noDefaultDate { + serverDateOnce.Do(updateServerDate) + dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte)) } - dst = appendHeaderLine(dst, strServer, server) - dst = appendHeaderLine(dst, strDate, serverDate.Load().([]byte)) // Append Content-Type only for non-zero responses // or if it is explicitly set. // See https://github.com/valyala/fasthttp/issues/28 . if h.ContentLength() != 0 || len(h.contentType) > 0 { - dst = appendHeaderLine(dst, strContentType, h.ContentType()) + contentType := h.ContentType() + if len(contentType) > 0 { + dst = appendHeaderLine(dst, strContentType, contentType) + } } if len(h.contentLengthBytes) > 0 { @@ -1411,7 +1667,7 @@ for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] - if !bytes.Equal(kv.key, strDate) { + if h.noDefaultDate || !bytes.Equal(kv.key, strDate) { dst = appendHeaderLine(dst, kv.key, kv.value) } } @@ -1447,12 +1703,28 @@ // Header returns request header representation. // -// The returned representation is valid until the next call to RequestHeader methods. +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Do not store references to returned value. Make copies instead. 
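Assuming the Server options NoDefaultServerHeader, NoDefaultDate and NoDefaultContentType (they are not part of this hunk) are what feed the flags consumed by AppendBytes above, a response without the default Server, Date and Content-Type lines could be configured roughly like this:

package main

import (
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("ok")
		},
		NoDefaultServerHeader: true, // no "Server:" line
		NoDefaultDate:         true, // no "Date:" line
		NoDefaultContentType:  true, // no "Content-Type:" line
	}
	log.Fatal(s.ListenAndServe("127.0.0.1:8080")) // address is illustrative
}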
func (h *RequestHeader) Header() []byte { h.bufKV.value = h.AppendBytes(h.bufKV.value[:0]) return h.bufKV.value } +// RawHeaders returns raw header key/value bytes. +// +// Depending on server configuration, header keys may be normalized to +// capital-case in place. +// +// This copy is set aside during parsing, so empty slice is returned for all +// cases where parsing did not happen. Similarly, request line is not stored +// during parsing and can not be returned. +// +// The slice is not safe to use after the handler returns. +func (h *RequestHeader) RawHeaders() []byte { + return h.rawHeaders +} + // String returns request header representation. func (h *RequestHeader) String() string { return string(h.Header()) @@ -1461,23 +1733,17 @@ // AppendBytes appends request header representation to dst and returns // the extended dst. func (h *RequestHeader) AppendBytes(dst []byte) []byte { - // there is no need in h.parseRawHeaders() here - raw headers are specially handled below. dst = append(dst, h.Method()...) dst = append(dst, ' ') dst = append(dst, h.RequestURI()...) dst = append(dst, ' ') - dst = append(dst, strHTTP11...) + dst = append(dst, h.Protocol()...) dst = append(dst, strCRLF...) - if !h.rawHeadersParsed && len(h.rawHeaders) > 0 { - return append(dst, h.rawHeaders...) - } - userAgent := h.UserAgent() - if len(userAgent) == 0 { - userAgent = defaultUserAgent + if len(userAgent) > 0 { + dst = appendHeaderLine(dst, strUserAgent, userAgent) } - dst = appendHeaderLine(dst, strUserAgent, userAgent) host := h.Host() if len(host) > 0 { @@ -1485,18 +1751,15 @@ } contentType := h.ContentType() - if !h.noBody() { - if len(contentType) == 0 { - contentType = strPostArgsContentType - } - dst = appendHeaderLine(dst, strContentType, contentType) - - if len(h.contentLengthBytes) > 0 { - dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes) - } - } else if len(contentType) > 0 { + if len(contentType) == 0 && !h.ignoreBody() { + contentType = strDefaultContentType + } + if len(contentType) > 0 { dst = appendHeaderLine(dst, strContentType, contentType) } + if len(h.contentLengthBytes) > 0 { + dst = appendHeaderLine(dst, strContentLength, h.contentLengthBytes) + } for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] @@ -1539,7 +1802,7 @@ return m + n, nil } -func (h *RequestHeader) noBody() bool { +func (h *RequestHeader) ignoreBody() bool { return h.IsGet() || h.IsHead() } @@ -1549,20 +1812,14 @@ return 0, err } + h.rawHeaders, _, err = readRawHeaders(h.rawHeaders[:0], buf[m:]) + if err != nil { + return 0, err + } var n int - if !h.noBody() || h.noHTTP11 { - n, err = h.parseHeaders(buf[m:]) - if err != nil { - return 0, err - } - h.rawHeadersParsed = true - } else { - var rawHeaders []byte - rawHeaders, n, err = readRawHeaders(h.rawHeaders[:0], buf[m:]) - if err != nil { - return 0, err - } - h.rawHeaders = rawHeaders + n, err = h.parseHeaders(buf[m:]) + if err != nil { + return 0, err } return m + n, nil } @@ -1580,6 +1837,9 @@ // parse protocol n := bytes.IndexByte(b, ' ') if n < 0 { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("cannot find whitespace in the first line of response") + } return 0, fmt.Errorf("cannot find whitespace in the first line of response %q", buf) } h.noHTTP11 = !bytes.Equal(b[:n], strHTTP11) @@ -1588,9 +1848,15 @@ // parse status code h.statusCode, n, err = parseUintBuf(b) if err != nil { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("cannot parse response status code: %s", err) + } return 0, fmt.Errorf("cannot parse response status code: 
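RawHeaders versus the re-serialized Header() output can be compared on a hand-written request (the field values are made up):

package main

import (
	"bufio"
	"fmt"
	"strings"

	"github.com/valyala/fasthttp"
)

func main() {
	raw := "POST /submit HTTP/1.1\r\n" +
		"Host: example.com\r\n" +
		"Content-Length: 5\r\n" +
		"Content-Type: text/plain\r\n" +
		"\r\n"

	var h fasthttp.RequestHeader
	if err := h.Read(bufio.NewReader(strings.NewReader(raw))); err != nil {
		panic(err)
	}

	// RawHeaders is the header block exactly as captured during parsing
	// (request line excluded); Header() re-serializes from the parsed fields.
	fmt.Printf("raw:           %q\n", h.RawHeaders())
	fmt.Printf("re-serialized: %q\n", h.Header())
}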
%s. Response %q", err, buf) } if len(b) > n && b[n] != ' ' { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("unexpected char at the end of status code") + } return 0, fmt.Errorf("unexpected char at the end of status code. Response %q", buf) } @@ -1610,63 +1876,43 @@ // parse method n := bytes.IndexByte(b, ' ') if n <= 0 { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("cannot find http request method") + } return 0, fmt.Errorf("cannot find http request method in %q", buf) } h.method = append(h.method[:0], b[:n]...) b = b[n+1:] + protoStr := strHTTP11 // parse requestURI n = bytes.LastIndexByte(b, ' ') if n < 0 { h.noHTTP11 = true n = len(b) + protoStr = strHTTP10 } else if n == 0 { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("requestURI cannot be empty") + } return 0, fmt.Errorf("requestURI cannot be empty in %q", buf) } else if !bytes.Equal(b[n+1:], strHTTP11) { h.noHTTP11 = true + protoStr = b[n+1:] } + + h.proto = append(h.proto[:0], protoStr...) h.requestURI = append(h.requestURI[:0], b[:n]...) return len(buf) - len(bNext), nil } -func peekRawHeader(buf, key []byte) []byte { - n := bytes.Index(buf, key) - if n < 0 { - return nil - } - if n > 0 && buf[n-1] != '\n' { - return nil - } - n += len(key) - if n >= len(buf) { - return nil - } - if buf[n] != ':' { - return nil - } - n++ - if buf[n] != ' ' { - return nil - } - n++ - buf = buf[n:] - n = bytes.IndexByte(buf, '\n') - if n < 0 { - return nil - } - if n > 0 && buf[n-1] == '\r' { - n-- - } - return buf[:n] -} - func readRawHeaders(dst, buf []byte) ([]byte, int, error) { - n := bytes.IndexByte(buf, '\n') + n := bytes.IndexByte(buf, nChar) if n < 0 { - return nil, 0, errNeedMore + return dst[:0], 0, errNeedMore } - if (n == 1 && buf[0] == '\r') || n == 0 { + if (n == 1 && buf[0] == rChar) || n == 0 { // empty headers return dst, n + 1, nil } @@ -1676,13 +1922,13 @@ m := n for { b = b[m:] - m = bytes.IndexByte(b, '\n') + m = bytes.IndexByte(b, nChar) if m < 0 { - return nil, 0, errNeedMore + return dst, 0, errNeedMore } m++ n += m - if (m == 2 && b[0] == '\r') || m == 1 { + if (m == 2 && b[0] == rChar) || m == 1 { dst = append(dst, buf[:n]...) return dst, n, nil } @@ -1699,37 +1945,53 @@ var err error var kv *argsKV for s.next() { - switch string(s.key) { - case "Content-Type": - h.contentType = append(h.contentType[:0], s.value...) - case "Server": - h.server = append(h.server[:0], s.value...) - case "Content-Length": - if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + if len(s.key) > 0 { + switch s.key[0] | 0x20 { + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) + continue + } + if caseInsensitiveCompare(s.key, strContentLength) { + if h.contentLength != -1 { + if h.contentLength, err = parseContentLength(s.value); err != nil { + h.contentLength = -2 + } else { + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + } + } + continue + } + if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + } + continue + } + case 's': + if caseInsensitiveCompare(s.key, strServer) { + h.server = append(h.server[:0], s.value...) 
+ continue + } + if caseInsensitiveCompare(s.key, strSetCookie) { + h.cookies, kv = allocArg(h.cookies) + kv.key = getCookieKey(kv.key, s.value) + kv.value = append(kv.value[:0], s.value...) + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + if len(s.value) > 0 && !bytes.Equal(s.value, strIdentity) { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) + } + continue } } - case "Transfer-Encoding": - if !bytes.Equal(s.value, strIdentity) { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked) - } - case "Set-Cookie": - h.cookies, kv = allocArg(h.cookies) - kv.key = getCookieKey(kv.key, s.value) - kv.value = append(kv.value[:0], s.value...) - case "Connection": - if bytes.Equal(s.value, strClose) { - h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value) - } - default: - h.h = appendArgBytes(h.h, s.key, s.value) + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } } if s.err != nil { @@ -1741,13 +2003,13 @@ h.contentLengthBytes = h.contentLengthBytes[:0] } if h.contentLength == -2 && !h.ConnectionUpgrade() && !h.mustSkipContentLength() { - h.h = setArgBytes(h.h, strTransferEncoding, strIdentity) + h.h = setArgBytes(h.h, strTransferEncoding, strIdentity, argsHasValue) h.connectionClose = true } if h.noHTTP11 && !h.connectionClose { // close connection for non-http/1.1 response unless 'Connection: keep-alive' is set. v := peekArgBytes(h.h, strConnection) - h.connectionClose = !hasHeaderValue(v, strKeepAlive) && !hasHeaderValue(v, strKeepAliveCamelCase) + h.connectionClose = !hasHeaderValue(v, strKeepAlive) } return len(buf) - len(s.b), nil @@ -1761,67 +2023,82 @@ s.disableNormalizing = h.disableNormalizing var err error for s.next() { - switch string(s.key) { - case "Host": - h.host = append(h.host[:0], s.value...) - case "User-Agent": - h.userAgent = append(h.userAgent[:0], s.value...) - case "Content-Type": - h.contentType = append(h.contentType[:0], s.value...) - case "Content-Length": - if h.contentLength != -1 { - if h.contentLength, err = parseContentLength(s.value); err != nil { - h.contentLength = -2 - } else { - h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) - } + if len(s.key) > 0 { + // Spaces between the header key and colon are not allowed. + // See RFC 7230, Section 3.2.4. + if bytes.IndexByte(s.key, ' ') != -1 || bytes.IndexByte(s.key, '\t') != -1 { + err = fmt.Errorf("invalid header key %q", s.key) + continue } - case "Transfer-Encoding": - if !bytes.Equal(s.value, strIdentity) { - h.contentLength = -1 - h.h = setArgBytes(h.h, strTransferEncoding, strChunked) - } - case "Connection": - if bytes.Equal(s.value, strClose) { - h.connectionClose = true - } else { - h.connectionClose = false - h.h = appendArgBytes(h.h, s.key, s.value) + + switch s.key[0] | 0x20 { + case 'h': + if caseInsensitiveCompare(s.key, strHost) { + h.host = append(h.host[:0], s.value...) + continue + } + case 'u': + if caseInsensitiveCompare(s.key, strUserAgent) { + h.userAgent = append(h.userAgent[:0], s.value...) + continue + } + case 'c': + if caseInsensitiveCompare(s.key, strContentType) { + h.contentType = append(h.contentType[:0], s.value...) 
+ continue + } + if caseInsensitiveCompare(s.key, strContentLength) { + if h.contentLength != -1 { + var nerr error + if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { + if err == nil { + err = nerr + } + h.contentLength = -2 + } else { + h.contentLengthBytes = append(h.contentLengthBytes[:0], s.value...) + } + } + continue + } + if caseInsensitiveCompare(s.key, strConnection) { + if bytes.Equal(s.value, strClose) { + h.connectionClose = true + } else { + h.connectionClose = false + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) + } + continue + } + case 't': + if caseInsensitiveCompare(s.key, strTransferEncoding) { + if !bytes.Equal(s.value, strIdentity) { + h.contentLength = -1 + h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) + } + continue + } } - default: - h.h = appendArgBytes(h.h, s.key, s.value) } + h.h = appendArgBytes(h.h, s.key, s.value, argsHasValue) } - if s.err != nil { + if s.err != nil && err == nil { + err = s.err + } + if err != nil { h.connectionClose = true - return 0, s.err + return 0, err } if h.contentLength < 0 { h.contentLengthBytes = h.contentLengthBytes[:0] } - if h.noBody() { - h.contentLength = 0 - h.contentLengthBytes = h.contentLengthBytes[:0] - } if h.noHTTP11 && !h.connectionClose { // close connection for non-http/1.1 request unless 'Connection: keep-alive' is set. v := peekArgBytes(h.h, strConnection) - h.connectionClose = !hasHeaderValue(v, strKeepAlive) && !hasHeaderValue(v, strKeepAliveCamelCase) - } - - return len(buf) - len(s.b), nil -} - -func (h *RequestHeader) parseRawHeaders() { - if h.rawHeadersParsed { - return - } - h.rawHeadersParsed = true - if len(h.rawHeaders) == 0 { - return + h.connectionClose = !hasHeaderValue(v, strKeepAlive) } - h.parseHeaders(h.rawHeaders) + return s.hLen, nil } func (h *RequestHeader) collectCookies() { @@ -1831,7 +2108,7 @@ for i, n := 0, len(h.h); i < n; i++ { kv := &h.h[i] - if bytes.Equal(kv.key, strCookie) { + if caseInsensitiveCompare(kv.key, strCookie) { h.cookies = parseRequestCookies(h.cookies, kv.value) tmp := *kv copy(h.h[i:], h.h[i+1:]) @@ -1861,20 +2138,60 @@ value []byte err error + // hLen stores header subslice len + hLen int + disableNormalizing bool + + // by checking whether the next line contains a colon or not to tell + // it's a header entry or a multi line value of current header entry. + // the side effect of this operation is that we know the index of the + // next colon and new line, so this can be used during next iteration, + // instead of find them again. + nextColon int + nextNewLine int + + initialized bool } func (s *headerScanner) next() bool { + if !s.initialized { + s.nextColon = -1 + s.nextNewLine = -1 + s.initialized = true + } bLen := len(s.b) - if bLen >= 2 && s.b[0] == '\r' && s.b[1] == '\n' { + if bLen >= 2 && s.b[0] == rChar && s.b[1] == nChar { s.b = s.b[2:] + s.hLen += 2 return false } - if bLen >= 1 && s.b[0] == '\n' { + if bLen >= 1 && s.b[0] == nChar { s.b = s.b[1:] + s.hLen++ return false } - n := bytes.IndexByte(s.b, ':') + var n int + if s.nextColon >= 0 { + n = s.nextColon + s.nextColon = -1 + } else { + n = bytes.IndexByte(s.b, ':') + + // There can't be a \n inside the header name, check for this. + x := bytes.IndexByte(s.b, nChar) + if x < 0 { + // A header name should always at some point be followed by a \n + // even if it's the one that terminates the header block. 
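The new whitespace check on header names can be triggered like this (the malformed header is intentional):

package main

import (
	"bufio"
	"fmt"
	"strings"

	"github.com/valyala/fasthttp"
)

func main() {
	// A space before the colon is rejected per RFC 7230, section 3.2.4.
	raw := "GET / HTTP/1.1\r\n" +
		"Host: example.com\r\n" +
		"Bad Header: x\r\n" +
		"\r\n"

	var h fasthttp.RequestHeader
	err := h.Read(bufio.NewReader(strings.NewReader(raw)))
	fmt.Println(err) // the error mentions an invalid header key
}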
+ s.err = errNeedMore + return false + } + if x < n { + // There was a \n before the : + s.err = errInvalidName + return false + } + } if n < 0 { s.err = errNeedMore return false @@ -1884,23 +2201,65 @@ n++ for len(s.b) > n && s.b[n] == ' ' { n++ + // the newline index is a relative index, and lines below trimed `s.b` by `n`, + // so the relative newline index also shifted forward. it's safe to decrease + // to a minus value, it means it's invalid, and will find the newline again. + s.nextNewLine-- } + s.hLen += n s.b = s.b[n:] - n = bytes.IndexByte(s.b, '\n') + if s.nextNewLine >= 0 { + n = s.nextNewLine + s.nextNewLine = -1 + } else { + n = bytes.IndexByte(s.b, nChar) + } if n < 0 { s.err = errNeedMore return false } + isMultiLineValue := false + for { + if n+1 >= len(s.b) { + break + } + if s.b[n+1] != ' ' && s.b[n+1] != '\t' { + break + } + d := bytes.IndexByte(s.b[n+1:], nChar) + if d <= 0 { + break + } else if d == 1 && s.b[n+1] == rChar { + break + } + e := n + d + 1 + if c := bytes.IndexByte(s.b[n+1:e], ':'); c >= 0 { + s.nextColon = c + s.nextNewLine = d - c - 1 + break + } + isMultiLineValue = true + n = e + } + if n >= len(s.b) { + s.err = errNeedMore + return false + } + oldB := s.b s.value = s.b[:n] + s.hLen += n + 1 s.b = s.b[n+1:] - if n > 0 && s.value[n-1] == '\r' { + if n > 0 && s.value[n-1] == rChar { n-- } for n > 0 && s.value[n-1] == ' ' { n-- } s.value = s.value[:n] + if isMultiLineValue { + s.value, s.b, s.hLen = normalizeHeaderValue(s.value, oldB, s.hLen) + } return true } @@ -1939,7 +2298,7 @@ var vs headerValueScanner vs.b = s for vs.next() { - if bytes.Equal(vs.value, value) { + if caseInsensitiveCompare(vs.value, value) { return true } } @@ -1947,12 +2306,12 @@ } func nextLine(b []byte) ([]byte, []byte, error) { - nNext := bytes.IndexByte(b, '\n') + nNext := bytes.IndexByte(b, nChar) if nNext < 0 { return nil, nil, errNeedMore } n := nNext - if n > 0 && b[n-1] == '\r' { + if n > 0 && b[n-1] == rChar { n-- } return b[:n], b[nNext+1:], nil @@ -1960,7 +2319,9 @@ func initHeaderKV(kv *argsKV, key, value string, disableNormalizing bool) { kv.key = getHeaderKeyBytes(kv, key, disableNormalizing) + // https://tools.ietf.org/html/rfc7230#section-3.2.4 kv.value = append(kv.value[:0], value...) 
+ kv.value = removeNewLines(kv.value) } func getHeaderKeyBytes(kv *argsKV, key string, disableNormalizing bool) []byte { @@ -1969,28 +2330,107 @@ return kv.key } +func normalizeHeaderValue(ov, ob []byte, headerLength int) (nv, nb []byte, nhl int) { + nv = ov + length := len(ov) + if length <= 0 { + return + } + write := 0 + shrunk := 0 + lineStart := false + for read := 0; read < length; read++ { + c := ov[read] + if c == rChar || c == nChar { + shrunk++ + if c == nChar { + lineStart = true + } + continue + } else if lineStart && c == '\t' { + c = ' ' + } else { + lineStart = false + } + nv[write] = c + write++ + } + + nv = nv[:write] + copy(ob[write:], ob[write+shrunk:]) + + // Check if we need to skip \r\n or just \n + skip := 0 + if ob[write] == rChar { + if ob[write+1] == nChar { + skip += 2 + } else { + skip++ + } + } else if ob[write] == nChar { + skip++ + } + + nb = ob[write+skip : len(ob)-shrunk] + nhl = headerLength - shrunk + return +} + func normalizeHeaderKey(b []byte, disableNormalizing bool) { if disableNormalizing { return } n := len(b) - up := true - for i := 0; i < n; i++ { - switch b[i] { - case '-': - up = true - default: - if up { - up = false - uppercaseByte(&b[i]) - } else { - lowercaseByte(&b[i]) + if n == 0 { + return + } + + b[0] = toUpperTable[b[0]] + for i := 1; i < n; i++ { + p := &b[i] + if *p == '-' { + i++ + if i < n { + b[i] = toUpperTable[b[i]] } + continue } + *p = toLowerTable[*p] } } +// removeNewLines will replace `\r` and `\n` with an empty space +func removeNewLines(raw []byte) []byte { + // check if a `\r` is present and save the position. + // if no `\r` is found, check if a `\n` is present. + foundR := bytes.IndexByte(raw, rChar) + foundN := bytes.IndexByte(raw, nChar) + start := 0 + + if foundN != -1 { + if foundR > foundN { + start = foundN + } else if foundR != -1 { + start = foundR + } + } else if foundR != -1 { + start = foundR + } else { + return raw + } + + for i := start; i < len(raw); i++ { + switch raw[i] { + case rChar, nChar: + raw[i] = ' ' + default: + continue + } + } + return raw +} + // AppendNormalizedHeaderKey appends normalized header key (name) to dst // and returns the resulting dst. // @@ -2021,7 +2461,26 @@ return AppendNormalizedHeaderKey(dst, b2s(key)) } -var errNeedMore = errors.New("need more data: cannot find trailing lf") +var ( + errNeedMore = errors.New("need more data: cannot find trailing lf") + errInvalidName = errors.New("invalid header name") + errSmallBuffer = errors.New("small read buffer. Increase ReadBufferSize") +) + +// ErrNothingRead is returned when a keep-alive connection is closed, +// either because the remote closed it or because of a read timeout. +type ErrNothingRead struct { + error +} + +// ErrSmallBuffer is returned when the provided buffer size is too small +// for reading request and/or response headers. +// +// ReadBufferSize value from Server or clients should reduce the number +// of such errors. 
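removeNewLines is what keeps Set from being usable for header injection; a sketch with an intentionally hostile value:

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	var h fasthttp.ResponseHeader

	// Set runs the value through removeNewLines, so CR and LF bytes become
	// spaces and the caller-supplied value cannot splice in an extra header line.
	h.Set("X-Echo", "ok\r\nSet-Cookie: injected=1")
	fmt.Printf("%q\n", h.Peek("X-Echo")) // "ok  Set-Cookie: injected=1"
}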
+type ErrSmallBuffer struct { + error +} func mustPeekBuffered(r *bufio.Reader) []byte { buf, err := r.Peek(r.Buffered()) diff -Nru golang-github-valyala-fasthttp-20160617/header_regression_test.go golang-github-valyala-fasthttp-1.31.0/header_regression_test.go --- golang-github-valyala-fasthttp-20160617/header_regression_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/header_regression_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -9,6 +9,8 @@ ) func TestIssue28ResponseWithoutBodyNoContentType(t *testing.T) { + t.Parallel() + var r Response // Empty response without content-type @@ -41,10 +43,12 @@ } func TestIssue6RequestHeaderSetContentType(t *testing.T) { - testIssue6RequestHeaderSetContentType(t, "GET") - testIssue6RequestHeaderSetContentType(t, "POST") - testIssue6RequestHeaderSetContentType(t, "PUT") - testIssue6RequestHeaderSetContentType(t, "PATCH") + t.Parallel() + + testIssue6RequestHeaderSetContentType(t, MethodGet) + testIssue6RequestHeaderSetContentType(t, MethodPost) + testIssue6RequestHeaderSetContentType(t, MethodPut) + testIssue6RequestHeaderSetContentType(t, MethodPatch) } func testIssue6RequestHeaderSetContentType(t *testing.T, method string) { @@ -77,11 +81,7 @@ if string(h.Method()) != method { t.Fatalf("unexpected method: %q. Expecting %q", h.Method(), method) } - if method != "GET" { - if h.ContentLength() != contentLength { - t.Fatalf("unexpected content-length: %d. Expecting %d. method=%q", h.ContentLength(), contentLength, method) - } - } else if h.ContentLength() != 0 { - t.Fatalf("unexpected content-length for GET method: %d. Expecting 0", h.ContentLength()) + if h.ContentLength() != contentLength { + t.Fatalf("unexpected content-length: %d. Expecting %d. method=%q", h.ContentLength(), contentLength, method) } } diff -Nru golang-github-valyala-fasthttp-20160617/headers.go golang-github-valyala-fasthttp-1.31.0/headers.go --- golang-github-valyala-fasthttp-20160617/headers.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/headers.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,164 @@ +package fasthttp + +// Headers +const ( + // Authentication + HeaderAuthorization = "Authorization" + HeaderProxyAuthenticate = "Proxy-Authenticate" + HeaderProxyAuthorization = "Proxy-Authorization" + HeaderWWWAuthenticate = "WWW-Authenticate" + + // Caching + HeaderAge = "Age" + HeaderCacheControl = "Cache-Control" + HeaderClearSiteData = "Clear-Site-Data" + HeaderExpires = "Expires" + HeaderPragma = "Pragma" + HeaderWarning = "Warning" + + // Client hints + HeaderAcceptCH = "Accept-CH" + HeaderAcceptCHLifetime = "Accept-CH-Lifetime" + HeaderContentDPR = "Content-DPR" + HeaderDPR = "DPR" + HeaderEarlyData = "Early-Data" + HeaderSaveData = "Save-Data" + HeaderViewportWidth = "Viewport-Width" + HeaderWidth = "Width" + + // Conditionals + HeaderETag = "ETag" + HeaderIfMatch = "If-Match" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderIfNoneMatch = "If-None-Match" + HeaderIfUnmodifiedSince = "If-Unmodified-Since" + HeaderLastModified = "Last-Modified" + HeaderVary = "Vary" + + // Connection management + HeaderConnection = "Connection" + HeaderKeepAlive = "Keep-Alive" + + // Content negotiation + HeaderAccept = "Accept" + HeaderAcceptCharset = "Accept-Charset" + HeaderAcceptEncoding = "Accept-Encoding" + HeaderAcceptLanguage = "Accept-Language" + + // Controls + HeaderCookie = "Cookie" + HeaderExpect = "Expect" + HeaderMaxForwards = "Max-Forwards" + HeaderSetCookie = "Set-Cookie" + + // CORS + 
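Typical use of the exported method and header-name constants from headers.go, here on a client-side request (the URL is illustrative):

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	req := fasthttp.AcquireRequest()
	defer fasthttp.ReleaseRequest(req)

	req.Header.SetMethod(fasthttp.MethodPost)
	req.Header.Set(fasthttp.HeaderContentType, "application/json")
	req.Header.Set(fasthttp.HeaderAccept, "application/json")
	req.SetRequestURI("http://example.com/api")

	fmt.Println(req.Header.IsPost()) // true
	fmt.Print(req.Header.String())
}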
HeaderAccessControlAllowCredentials = "Access-Control-Allow-Credentials" + HeaderAccessControlAllowHeaders = "Access-Control-Allow-Headers" + HeaderAccessControlAllowMethods = "Access-Control-Allow-Methods" + HeaderAccessControlAllowOrigin = "Access-Control-Allow-Origin" + HeaderAccessControlExposeHeaders = "Access-Control-Expose-Headers" + HeaderAccessControlMaxAge = "Access-Control-Max-Age" + HeaderAccessControlRequestHeaders = "Access-Control-Request-Headers" + HeaderAccessControlRequestMethod = "Access-Control-Request-Method" + HeaderOrigin = "Origin" + HeaderTimingAllowOrigin = "Timing-Allow-Origin" + HeaderXPermittedCrossDomainPolicies = "X-Permitted-Cross-Domain-Policies" + + // Do Not Track + HeaderDNT = "DNT" + HeaderTk = "Tk" + + // Downloads + HeaderContentDisposition = "Content-Disposition" + + // Message body information + HeaderContentEncoding = "Content-Encoding" + HeaderContentLanguage = "Content-Language" + HeaderContentLength = "Content-Length" + HeaderContentLocation = "Content-Location" + HeaderContentType = "Content-Type" + + // Proxies + HeaderForwarded = "Forwarded" + HeaderVia = "Via" + HeaderXForwardedFor = "X-Forwarded-For" + HeaderXForwardedHost = "X-Forwarded-Host" + HeaderXForwardedProto = "X-Forwarded-Proto" + + // Redirects + HeaderLocation = "Location" + + // Request context + HeaderFrom = "From" + HeaderHost = "Host" + HeaderReferer = "Referer" + HeaderReferrerPolicy = "Referrer-Policy" + HeaderUserAgent = "User-Agent" + + // Response context + HeaderAllow = "Allow" + HeaderServer = "Server" + + // Range requests + HeaderAcceptRanges = "Accept-Ranges" + HeaderContentRange = "Content-Range" + HeaderIfRange = "If-Range" + HeaderRange = "Range" + + // Security + HeaderContentSecurityPolicy = "Content-Security-Policy" + HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" + HeaderCrossOriginResourcePolicy = "Cross-Origin-Resource-Policy" + HeaderExpectCT = "Expect-CT" + HeaderFeaturePolicy = "Feature-Policy" + HeaderPublicKeyPins = "Public-Key-Pins" + HeaderPublicKeyPinsReportOnly = "Public-Key-Pins-Report-Only" + HeaderStrictTransportSecurity = "Strict-Transport-Security" + HeaderUpgradeInsecureRequests = "Upgrade-Insecure-Requests" + HeaderXContentTypeOptions = "X-Content-Type-Options" + HeaderXDownloadOptions = "X-Download-Options" + HeaderXFrameOptions = "X-Frame-Options" + HeaderXPoweredBy = "X-Powered-By" + HeaderXXSSProtection = "X-XSS-Protection" + + // Server-sent event + HeaderLastEventID = "Last-Event-ID" + HeaderNEL = "NEL" + HeaderPingFrom = "Ping-From" + HeaderPingTo = "Ping-To" + HeaderReportTo = "Report-To" + + // Transfer coding + HeaderTE = "TE" + HeaderTrailer = "Trailer" + HeaderTransferEncoding = "Transfer-Encoding" + + // WebSockets + HeaderSecWebSocketAccept = "Sec-WebSocket-Accept" + HeaderSecWebSocketExtensions = "Sec-WebSocket-Extensions" + HeaderSecWebSocketKey = "Sec-WebSocket-Key" + HeaderSecWebSocketProtocol = "Sec-WebSocket-Protocol" + HeaderSecWebSocketVersion = "Sec-WebSocket-Version" + + // Other + HeaderAcceptPatch = "Accept-Patch" + HeaderAcceptPushPolicy = "Accept-Push-Policy" + HeaderAcceptSignature = "Accept-Signature" + HeaderAltSvc = "Alt-Svc" + HeaderDate = "Date" + HeaderIndex = "Index" + HeaderLargeAllocation = "Large-Allocation" + HeaderLink = "Link" + HeaderPushPolicy = "Push-Policy" + HeaderRetryAfter = "Retry-After" + HeaderServerTiming = "Server-Timing" + HeaderSignature = "Signature" + HeaderSignedHeaders = "Signed-Headers" + HeaderSourceMap = "SourceMap" + HeaderUpgrade = "Upgrade" + 
HeaderXDNSPrefetchControl = "X-DNS-Prefetch-Control" + HeaderXPingback = "X-Pingback" + HeaderXRequestedWith = "X-Requested-With" + HeaderXRobotsTag = "X-Robots-Tag" + HeaderXUACompatible = "X-UA-Compatible" +) diff -Nru golang-github-valyala-fasthttp-20160617/header_test.go golang-github-valyala-fasthttp-1.31.0/header_test.go --- golang-github-valyala-fasthttp-20160617/header_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/header_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,14 +3,323 @@ import ( "bufio" "bytes" + "encoding/base64" "fmt" "io" "io/ioutil" + "net/http" + "reflect" "strings" "testing" ) +func TestResponseHeaderAddContentType(t *testing.T) { + t.Parallel() + + var h ResponseHeader + h.Add("Content-Type", "test") + + got := string(h.Peek("Content-Type")) + expected := "test" + if got != expected { + t.Errorf("expected %q got %q", expected, got) + } + + var buf bytes.Buffer + h.WriteTo(&buf) //nolint:errcheck + + if n := strings.Count(buf.String(), "Content-Type: "); n != 1 { + t.Errorf("Content-Type occurred %d times", n) + } +} + +func TestResponseHeaderMultiLineValue(t *testing.T) { + t.Parallel() + + s := "HTTP/1.1 200 OK\r\n" + + "EmptyValue1:\r\n" + + "Content-Type: foo/bar;\r\n\tnewline;\r\n another/newline\r\n" + + "Foo: Bar\r\n" + + "Multi-Line: one;\r\n two\r\n" + + "Values: v1;\r\n v2; v3;\r\n v4;\tv5\r\n" + + "\r\n" + header := new(ResponseHeader) + if _, err := header.parse([]byte(s)); err != nil { + t.Fatalf("parse headers with multi-line values failed, %s", err) + } + response, err := http.ReadResponse(bufio.NewReader(strings.NewReader(s)), nil) + if err != nil { + t.Fatalf("parse response using net/http failed, %s", err) + } + + for name, vals := range response.Header { + got := string(header.Peek(name)) + want := vals[0] + + if got != want { + t.Errorf("unexpected %s got: %q want: %q", name, got, want) + } + } +} + +func TestResponseHeaderMultiLineName(t *testing.T) { + t.Parallel() + + s := "HTTP/1.1 200 OK\r\n" + + "Host: golang.org\r\n" + + "Gopher-New-\r\n" + + " Line: This is a header on multiple lines\r\n" + + "\r\n" + header := new(ResponseHeader) + if _, err := header.parse([]byte(s)); err != errInvalidName { + m := make(map[string]string) + header.VisitAll(func(key, value []byte) { + m[string(key)] = string(value) + }) + t.Errorf("expected error, got %q (%v)", m, err) + } +} + +func TestResponseHeaderMultiLinePaniced(t *testing.T) { + t.Parallel() + + // Input generated by fuzz testing that caused the parser to panic. + s, _ := base64.StdEncoding.DecodeString("aAEAIDoKKDoKICA6CgkKCiA6CiA6CgkpCiA6CiA6CiA6Cig6CiAgOgoJCgogOgogOgoJKQogOgogOgogOgogOgogOgoJOg86CiA6CiA6Cig6CiAyCg==") + header := new(RequestHeader) + header.parse(s) //nolint:errcheck +} + +func TestResponseHeaderEmptyValueFromHeader(t *testing.T) { + t.Parallel() + + var h1 ResponseHeader + h1.SetContentType("foo/bar") + h1.Set("EmptyValue1", "") + h1.Set("EmptyValue2", " ") + s := h1.String() + + var h ResponseHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.ContentType()) != string(h1.ContentType()) { + t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), h1.ContentType()) + } + v1 := h.Peek("EmptyValue1") + if len(v1) > 0 { + t.Fatalf("expecting empty value. Got %q", v1) + } + v2 := h.Peek("EmptyValue2") + if len(v2) > 0 { + t.Fatalf("expecting empty value. 
Got %q", v2) + } +} + +func TestResponseHeaderEmptyValueFromString(t *testing.T) { + t.Parallel() + + s := "HTTP/1.1 200 OK\r\n" + + "EmptyValue1:\r\n" + + "Content-Type: foo/bar\r\n" + + "EmptyValue2: \r\n" + + "\r\n" + + var h ResponseHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.ContentType()) != "foo/bar" { + t.Fatalf("unexpected content-type: %q. Expecting %q", h.ContentType(), "foo/bar") + } + v1 := h.Peek("EmptyValue1") + if len(v1) > 0 { + t.Fatalf("expecting empty value. Got %q", v1) + } + v2 := h.Peek("EmptyValue2") + if len(v2) > 0 { + t.Fatalf("expecting empty value. Got %q", v2) + } +} + +func TestRequestHeaderEmptyValueFromHeader(t *testing.T) { + t.Parallel() + + var h1 RequestHeader + h1.SetRequestURI("/foo/bar") + h1.SetHost("foobar") + h1.Set("EmptyValue1", "") + h1.Set("EmptyValue2", " ") + s := h1.String() + + var h RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != string(h1.Host()) { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), h1.Host()) + } + v1 := h.Peek("EmptyValue1") + if len(v1) > 0 { + t.Fatalf("expecting empty value. Got %q", v1) + } + v2 := h.Peek("EmptyValue2") + if len(v2) > 0 { + t.Fatalf("expecting empty value. Got %q", v2) + } +} + +func TestRequestHeaderEmptyValueFromString(t *testing.T) { + t.Parallel() + + s := "GET / HTTP/1.1\r\n" + + "EmptyValue1:\r\n" + + "Host: foobar\r\n" + + "EmptyValue2: \r\n" + + "\r\n" + var h RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != "foobar" { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") + } + v1 := h.Peek("EmptyValue1") + if len(v1) > 0 { + t.Fatalf("expecting empty value. Got %q", v1) + } + v2 := h.Peek("EmptyValue2") + if len(v2) > 0 { + t.Fatalf("expecting empty value. Got %q", v2) + } +} + +func TestRequestRawHeaders(t *testing.T) { + t.Parallel() + + kvs := "hOsT: foobar\r\n" + + "value: b\r\n" + + "\r\n" + t.Run("normalized", func(t *testing.T) { + s := "GET / HTTP/1.1\r\n" + kvs + exp := kvs + var h RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != "foobar" { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") + } + v2 := h.Peek("Value") + if !bytes.Equal(v2, []byte{'b'}) { + t.Fatalf("expecting non empty value. Got %q", v2) + } + if raw := h.RawHeaders(); string(raw) != exp { + t.Fatalf("expected header %q, got %q", exp, raw) + } + }) + for _, n := range []int{0, 1, 4, 8} { + t.Run(fmt.Sprintf("post-%dk", n), func(t *testing.T) { + l := 1024 * n + body := make([]byte, l) + for i := range body { + body[i] = 'a' + } + cl := fmt.Sprintf("Content-Length: %d\r\n", l) + s := "POST / HTTP/1.1\r\n" + cl + kvs + string(body) + exp := cl + kvs + var h RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != "foobar" { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") + } + v2 := h.Peek("Value") + if !bytes.Equal(v2, []byte{'b'}) { + t.Fatalf("expecting non empty value. 
Got %q", v2) + } + if raw := h.RawHeaders(); string(raw) != exp { + t.Fatalf("expected header %q, got %q", exp, raw) + } + }) + } + t.Run("http10", func(t *testing.T) { + s := "GET / HTTP/1.0\r\n" + kvs + exp := kvs + var h RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != "foobar" { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "foobar") + } + v2 := h.Peek("Value") + if !bytes.Equal(v2, []byte{'b'}) { + t.Fatalf("expecting non empty value. Got %q", v2) + } + if raw := h.RawHeaders(); string(raw) != exp { + t.Fatalf("expected header %q, got %q", exp, raw) + } + }) + t.Run("no-kvs", func(t *testing.T) { + s := "GET / HTTP/1.1\r\n\r\n" + exp := "" + var h RequestHeader + h.DisableNormalizing() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(h.Host()) != "" { + t.Fatalf("unexpected host: %q. Expecting %q", h.Host(), "") + } + v1 := h.Peek("NoKey") + if len(v1) > 0 { + t.Fatalf("expecting empty value. Got %q", v1) + } + if raw := h.RawHeaders(); string(raw) != exp { + t.Fatalf("expected header %q, got %q", exp, raw) + } + }) +} + +func TestRequestHeaderSetCookieWithSpecialChars(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.Set("Cookie", "ID&14") + s := h.String() + + if !strings.Contains(s, "Cookie: ID&14") { + t.Fatalf("Missing cookie in request header: [%s]", s) + } + + var h1 RequestHeader + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := h1.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + cookie := h1.Peek(HeaderCookie) + if string(cookie) != "ID&14" { + t.Fatalf("unexpected cooke: %q. Expecting %q", cookie, "ID&14") + } + + cookie = h1.Cookie("") + if string(cookie) != "ID&14" { + t.Fatalf("unexpected cooke: %q. Expecting %q", cookie, "ID&14") + } +} + func TestResponseHeaderDefaultStatusCode(t *testing.T) { + t.Parallel() + var h ResponseHeader statusCode := h.StatusCode() if statusCode != StatusOK { @@ -19,6 +328,8 @@ } func TestResponseHeaderDelClientCookie(t *testing.T) { + t.Parallel() + cookieName := "foobar" var h ResponseHeader @@ -41,10 +352,14 @@ } func TestResponseHeaderAdd(t *testing.T) { + t.Parallel() + m := make(map[string]struct{}) var h ResponseHeader h.Add("aaa", "bbb") + h.Add("content-type", "xxx") m["bbb"] = struct{}{} + m["xxx"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) @@ -56,12 +371,11 @@ h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "Content-Type": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } delete(m, string(v)) - case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } @@ -79,35 +393,38 @@ h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "Content-Type": m[string(v)] = struct{}{} - case "Content-Type": default: t.Fatalf("unexpected key found: %q", k) } }) - if len(m) != 11 { - t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) + if len(m) != 12 { + t.Fatalf("unexpected number of headers: %d. 
Expecting 12", len(m)) } } func TestRequestHeaderAdd(t *testing.T) { + t.Parallel() + m := make(map[string]struct{}) var h RequestHeader h.Add("aaa", "bbb") + h.Add("user-agent", "xxx") m["bbb"] = struct{}{} + m["xxx"] = struct{}{} for i := 0; i < 10; i++ { v := fmt.Sprintf("%d", i) h.Add("Foo-Bar", v) m[v] = struct{}{} } - if h.Len() != 11 { - t.Fatalf("unexpected header len %d. Expecting 11", h.Len()) + if h.Len() != 12 { + t.Fatalf("unexpected header len %d. Expecting 12", h.Len()) } h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "User-Agent": if _, ok := m[string(v)]; !ok { t.Fatalf("unexpected value found %q. key %q", v, k) } @@ -129,15 +446,14 @@ h.VisitAll(func(k, v []byte) { switch string(k) { - case "Aaa", "Foo-Bar": + case "Aaa", "Foo-Bar", "User-Agent": m[string(v)] = struct{}{} - case "User-Agent": default: t.Fatalf("unexpected key found: %q", k) } }) - if len(m) != 11 { - t.Fatalf("unexpected number of headers: %d. Expecting 11", len(m)) + if len(m) != 12 { + t.Fatalf("unexpected number of headers: %d. Expecting 12", len(m)) } s1 := h1.String() if s != s1 { @@ -146,6 +462,8 @@ } func TestHasHeaderValue(t *testing.T) { + t.Parallel() + testHasHeaderValue(t, "foobar", "foobar", true) testHasHeaderValue(t, "foobar", "foo", false) testHasHeaderValue(t, "foobar", "bar", false) @@ -169,12 +487,14 @@ } func TestRequestHeaderDel(t *testing.T) { + t.Parallel() + var h RequestHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") - h.Set("Connection", "keep-alive") + h.Set(HeaderConnection, "keep-alive") h.Set("Content-Type", "aaa") - h.Set("Host", "aaabbb") + h.Set(HeaderHost, "aaabbb") h.Set("User-Agent", "asdfas") h.Set("Content-Length", "1123") h.Set("Cookie", "foobar=baz") @@ -195,27 +515,27 @@ if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Connection") + hv = h.Peek(HeaderConnection) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Content-Type") + hv = h.Peek(HeaderContentType) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Host") + hv = h.Peek(HeaderHost) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("User-Agent") + hv = h.Peek(HeaderUserAgent) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Content-Length") + hv = h.Peek(HeaderContentLength) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Cookie") + hv = h.Peek(HeaderCookie) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } @@ -230,13 +550,15 @@ } func TestResponseHeaderDel(t *testing.T) { + t.Parallel() + var h ResponseHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") - h.Set("Connection", "keep-alive") - h.Set("Content-Type", "aaa") - h.Set("Server", "aaabbb") - h.Set("Content-Length", "1123") + h.Set(HeaderConnection, "keep-alive") + h.Set(HeaderContentType, "aaa") + h.Set(HeaderServer, "aaabbb") + h.Set(HeaderContentLength, "1123") var c Cookie c.SetKey("foo") @@ -246,7 +568,7 @@ h.Del("foo-bar") h.Del("connection") h.DelBytes([]byte("content-type")) - h.Del("Server") + h.Del(HeaderServer) h.Del("content-length") h.Del("set-cookie") @@ -258,19 +580,19 @@ if len(hv) > 0 { t.Fatalf("non-zero header value: %q", hv) } - hv = h.Peek("Connection") + hv = h.Peek(HeaderConnection) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Content-Type") + hv = h.Peek(HeaderContentType) if string(hv) != string(defaultContentType) { t.Fatalf("unexpected content-type: %q. 
Expecting %q", hv, defaultContentType) } - hv = h.Peek("Server") + hv = h.Peek(HeaderServer) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } - hv = h.Peek("Content-Length") + hv = h.Peek(HeaderContentLength) if len(hv) > 0 { t.Fatalf("non-zero value: %q", hv) } @@ -284,6 +606,8 @@ } func TestAppendNormalizedHeaderKeyBytes(t *testing.T) { + t.Parallel() + testAppendNormalizedHeaderKeyBytes(t, "", "") testAppendNormalizedHeaderKeyBytes(t, "Content-Type", "Content-Type") testAppendNormalizedHeaderKeyBytes(t, "foO-bAr-BAZ", "Foo-Bar-Baz") @@ -299,6 +623,8 @@ } func TestRequestHeaderHTTP10ConnectionClose(t *testing.T) { + t.Parallel() + s := "GET / HTTP/1.0\r\nHost: foobar\r\n\r\n" var h RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) @@ -306,15 +632,14 @@ t.Fatalf("unexpected error: %s", err) } - if !h.connectionCloseFast() { - t.Fatalf("expecting 'Connection: close' request header") - } if !h.ConnectionClose() { t.Fatalf("expecting 'Connection: close' request header") } } func TestRequestHeaderHTTP10ConnectionKeepAlive(t *testing.T) { + t.Parallel() + s := "GET / HTTP/1.0\r\nHost: foobar\r\nConnection: keep-alive\r\n\r\n" var h RequestHeader br := bufio.NewReader(bytes.NewBufferString(s)) @@ -327,36 +652,41 @@ } } -func TestBufferStartEnd(t *testing.T) { - testBufferStartEnd(t, "", "", "") - testBufferStartEnd(t, "foobar", "foobar", "") +func TestBufferSnippet(t *testing.T) { + t.Parallel() + + testBufferSnippet(t, "", `""`) + testBufferSnippet(t, "foobar", `"foobar"`) b := string(createFixedBody(199)) - testBufferStartEnd(t, b, b, "") + bExpected := fmt.Sprintf("%q", b) + testBufferSnippet(t, b, bExpected) for i := 0; i < 10; i++ { b += "foobar" - testBufferStartEnd(t, b, b, "") + bExpected = fmt.Sprintf("%q", b) + testBufferSnippet(t, b, bExpected) } b = string(createFixedBody(400)) - testBufferStartEnd(t, b, b, "") + bExpected = fmt.Sprintf("%q", b) + testBufferSnippet(t, b, bExpected) for i := 0; i < 10; i++ { b += "sadfqwer" - testBufferStartEnd(t, b, b[:200], b[len(b)-200:]) + bExpected = fmt.Sprintf("%q...%q", b[:200], b[len(b)-200:]) + testBufferSnippet(t, b, bExpected) } } -func testBufferStartEnd(t *testing.T, buf, expectedStart, expectedEnd string) { - start, end := bufferStartEnd([]byte(buf)) - if string(start) != expectedStart { - t.Fatalf("unexpected start %q. Expecting %q. buf %q", start, expectedStart, buf) - } - if string(end) != expectedEnd { - t.Fatalf("unexpected end %q. Expecting %q. buf %q", end, expectedEnd, buf) +func testBufferSnippet(t *testing.T, buf, expectedSnippet string) { + snippet := bufferSnippet([]byte(buf)) + if snippet != expectedSnippet { + t.Fatalf("unexpected snippet %s. 
Expecting %s", snippet, expectedSnippet) } } func TestResponseHeaderTrailingCRLFSuccess(t *testing.T) { + t.Parallel() + trailingCRLF := "\r\n\r\n\r\n" s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF @@ -377,6 +707,8 @@ } func TestResponseHeaderTrailingCRLFError(t *testing.T) { + t.Parallel() + trailingCRLF := "\r\nerror\r\n\r\n" s := "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\n\r\n" + trailingCRLF @@ -397,6 +729,8 @@ } func TestRequestHeaderTrailingCRLFSuccess(t *testing.T) { + t.Parallel() + trailingCRLF := "\r\n\r\n\r\n" s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF @@ -417,6 +751,8 @@ } func TestRequestHeaderTrailingCRLFError(t *testing.T) { + t.Parallel() + trailingCRLF := "\r\nerror\r\n\r\n" s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n\r\n" + trailingCRLF @@ -437,6 +773,8 @@ } func TestRequestHeaderReadEOF(t *testing.T) { + t.Parallel() + var r RequestHeader br := bufio.NewReader(&bytes.Buffer{}) @@ -460,6 +798,8 @@ } func TestResponseHeaderReadEOF(t *testing.T) { + t.Parallel() + var r ResponseHeader br := bufio.NewReader(&bytes.Buffer{}) @@ -483,6 +823,8 @@ } func TestResponseHeaderOldVersion(t *testing.T) { + t.Parallel() + var h ResponseHeader s := "HTTP/1.0 200 OK\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345" @@ -504,6 +846,8 @@ } func TestRequestHeaderSetByteRange(t *testing.T) { + t.Parallel() + testRequestHeaderSetByteRange(t, 0, 10, "bytes=0-10") testRequestHeaderSetByteRange(t, 123, -1, "bytes=123-") testRequestHeaderSetByteRange(t, -234, 58349, "bytes=-234") @@ -512,13 +856,15 @@ func testRequestHeaderSetByteRange(t *testing.T, startPos, endPos int, expectedV string) { var h RequestHeader h.SetByteRange(startPos, endPos) - v := h.Peek("Range") + v := h.Peek(HeaderRange) if string(v) != expectedV { t.Fatalf("unexpected range: %q. Expecting %q. startPos=%d, endPos=%d", v, expectedV, startPos, endPos) } } func TestResponseHeaderSetContentRange(t *testing.T) { + t.Parallel() + testResponseHeaderSetContentRange(t, 0, 0, 1, "bytes 0-0/1") testResponseHeaderSetContentRange(t, 123, 456, 789, "bytes 123-456/789") } @@ -526,7 +872,7 @@ func testResponseHeaderSetContentRange(t *testing.T, startPos, endPos, contentLength int, expectedV string) { var h ResponseHeader h.SetContentRange(startPos, endPos, contentLength) - v := h.Peek("Content-Range") + v := h.Peek(HeaderContentRange) if string(v) != expectedV { t.Fatalf("unexpected content-range: %q. Expecting %q. startPos=%d, endPos=%d, contentLength=%d", v, expectedV, startPos, endPos, contentLength) @@ -534,6 +880,8 @@ } func TestRequestHeaderHasAcceptEncoding(t *testing.T) { + t.Parallel() + testRequestHeaderHasAcceptEncoding(t, "", "gzip", false) testRequestHeaderHasAcceptEncoding(t, "gzip", "sdhc", false) testRequestHeaderHasAcceptEncoding(t, "deflate", "deflate", true) @@ -551,7 +899,7 @@ func testRequestHeaderHasAcceptEncoding(t *testing.T, ae, v string, resultExpected bool) { var h RequestHeader - h.Set("Accept-Encoding", ae) + h.Set(HeaderAcceptEncoding, ae) result := h.HasAcceptEncoding(v) if result != resultExpected { t.Fatalf("unexpected result in HasAcceptEncoding(%q, %q): %v. 
Expecting %v", ae, v, result, resultExpected) @@ -559,6 +907,8 @@ } func TestRequestMultipartFormBoundary(t *testing.T) { + t.Parallel() + testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=foobar\r\n\r\n", "foobar") // incorrect content-type @@ -573,6 +923,9 @@ // boundary after other content-type params testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; foo=bar; boundary=--aaabb \r\n\r\n", "--aaabb") + // quoted boundary + testRequestMultipartFormBoundary(t, "POST / HTTP/1.1\r\nContent-Type: multipart/form-data; boundary=\"foobar\"\r\n\r\n", "foobar") + var h RequestHeader h.SetMultipartFormBoundary("foobarbaz") b := h.MultipartFormBoundary() @@ -596,6 +949,8 @@ } func TestResponseHeaderConnectionUpgrade(t *testing.T) { + t.Parallel() + testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: Upgrade, HTTP2-Settings\r\n\r\n", true, true) testResponseHeaderConnectionUpgrade(t, "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nConnection: keep-alive, Upgrade\r\n\r\n", @@ -637,6 +992,8 @@ } func TestRequestHeaderConnectionUpgrade(t *testing.T) { + t.Parallel() + testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: Upgrade, HTTP2-Settings\r\nHost: foobar.com\r\n\r\n", true, true) testRequestHeaderConnectionUpgrade(t, "GET /foobar HTTP/1.1\r\nConnection: keep-alive,Upgrade\r\nHost: foobar.com\r\n\r\n", @@ -682,6 +1039,8 @@ } func TestRequestHeaderProxyWithCookie(t *testing.T) { + t.Parallel() + // Proxy request header (read it, then write it without touching any headers). var h RequestHeader r := bytes.NewBufferString("GET /foo HTTP/1.1\r\nFoo: bar\r\nHost: aaa.com\r\nCookie: foo=bar; bazzz=aaaaaaa; x=y\r\nCookie: aqqqqq=123\r\n\r\n") @@ -726,49 +1085,10 @@ } } -func TestPeekRawHeader(t *testing.T) { - // empty header - testPeekRawHeader(t, "", "Foo-Bar", "") - - // different case - testPeekRawHeader(t, "Content-Length: 3443\r\n", "content-length", "") - - // no trailing crlf - testPeekRawHeader(t, "Content-Length: 234", "Content-Length", "") - - // single header - testPeekRawHeader(t, "Content-Length: 12345\r\n", "Content-Length", "12345") - - // multiple headers - testPeekRawHeader(t, "Host: foobar\r\nContent-Length: 434\r\nFoo: bar\r\n\r\n", "Content-Length", "434") - - // lf without cr - testPeekRawHeader(t, "Foo: bar\nConnection: close\nAaa: bbb\ncc: ddd\n", "Connection", "close") -} - -func testPeekRawHeader(t *testing.T, rawHeaders, key string, expectedValue string) { - v := peekRawHeader([]byte(rawHeaders), []byte(key)) - if string(v) != expectedValue { - t.Fatalf("unexpected raw headers value %q. Expected %q. key %q, rawHeaders %q", v, expectedValue, key, rawHeaders) - } -} - func TestResponseHeaderFirstByteReadEOF(t *testing.T) { - var h ResponseHeader + t.Parallel() - r := &errorReader{fmt.Errorf("non-eof error")} - br := bufio.NewReader(r) - err := h.Read(br) - if err == nil { - t.Fatalf("expecting error") - } - if err != io.EOF { - t.Fatalf("unexpected error %s. 
Expecting %s", err, io.EOF) - } -} - -func TestRequestHeaderFirstByteReadEOF(t *testing.T) { - var h RequestHeader + var h ResponseHeader r := &errorReader{fmt.Errorf("non-eof error")} br := bufio.NewReader(r) @@ -790,23 +1110,18 @@ } func TestRequestHeaderEmptyMethod(t *testing.T) { + t.Parallel() + var h RequestHeader if !h.IsGet() { t.Fatalf("empty method must be equivalent to GET") } - if h.IsPost() { - t.Fatalf("empty method cannot be POST") - } - if h.IsHead() { - t.Fatalf("empty method cannot be HEAD") - } - if h.IsDelete() { - t.Fatalf("empty method cannot be DELETE") - } } func TestResponseHeaderHTTPVer(t *testing.T) { + t.Parallel() + // non-http/1.1 testResponseHeaderHTTPVer(t, "HTTP/1.0 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true) testResponseHeaderHTTPVer(t, "HTTP/0.9 200 OK\r\nContent-Type: aaa\r\nContent-Length: 123\r\n\r\n", true) @@ -817,6 +1132,8 @@ } func TestRequestHeaderHTTPVer(t *testing.T) { + t.Parallel() + // non-http/1.1 testRequestHeaderHTTPVer(t, "GET / HTTP/1.0\r\nHost: aa.com\r\n\r\n", true) testRequestHeaderHTTPVer(t, "GET / HTTP/0.9\r\nHost: aa.com\r\n\r\n", true) @@ -857,10 +1174,12 @@ } func TestResponseHeaderCopyTo(t *testing.T) { + t.Parallel() + var h ResponseHeader - h.Set("Set-Cookie", "foo=bar") - h.Set("Content-Type", "foobar") + h.Set(HeaderSetCookie, "foo=bar") + h.Set(HeaderContentType, "foobar") h.Set("AAA-BBB", "aaaa") var h1 ResponseHeader @@ -868,28 +1187,38 @@ if !bytes.Equal(h1.Peek("Set-cookie"), h.Peek("Set-Cookie")) { t.Fatalf("unexpected cookie %q. Expected %q", h1.Peek("set-cookie"), h.Peek("set-cookie")) } - if !bytes.Equal(h1.Peek("Content-Type"), h.Peek("Content-Type")) { + if !bytes.Equal(h1.Peek(HeaderContentType), h.Peek(HeaderContentType)) { t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type")) } if !bytes.Equal(h1.Peek("aaa-bbb"), h.Peek("AAA-BBB")) { t.Fatalf("unexpected aaa-bbb %q. Expected %q", h1.Peek("aaa-bbb"), h.Peek("aaa-bbb")) } + + // flush buf + h.bufKV = argsKV{} + h1.bufKV = argsKV{} + + if !reflect.DeepEqual(h, h1) { //nolint:govet + t.Fatalf("ResponseHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet + } } func TestRequestHeaderCopyTo(t *testing.T) { + t.Parallel() + var h RequestHeader - h.Set("Cookie", "aa=bb; cc=dd") - h.Set("Content-Type", "foobar") - h.Set("Host", "aaaa") + h.Set(HeaderCookie, "aa=bb; cc=dd") + h.Set(HeaderContentType, "foobar") + h.Set(HeaderHost, "aaaa") h.Set("aaaxxx", "123") var h1 RequestHeader h.CopyTo(&h1) - if !bytes.Equal(h1.Peek("cookie"), h.Peek("Cookie")) { + if !bytes.Equal(h1.Peek("cookie"), h.Peek(HeaderCookie)) { t.Fatalf("unexpected cookie after copying: %q. Expected %q", h1.Peek("cookie"), h.Peek("cookie")) } - if !bytes.Equal(h1.Peek("content-type"), h.Peek("Content-Type")) { + if !bytes.Equal(h1.Peek("content-type"), h.Peek(HeaderContentType)) { t.Fatalf("unexpected content-type %q. Expected %q", h1.Peek("content-type"), h.Peek("content-type")) } if !bytes.Equal(h1.Peek("host"), h.Peek("host")) { @@ -898,13 +1227,79 @@ if !bytes.Equal(h1.Peek("aaaxxx"), h.Peek("aaaxxx")) { t.Fatalf("unexpected aaaxxx %q. 
Expected %q", h1.Peek("aaaxxx"), h.Peek("aaaxxx")) } + + // flush buf + h.bufKV = argsKV{} + h1.bufKV = argsKV{} + + if !reflect.DeepEqual(h, h1) { //nolint:govet + t.Fatalf("RequestHeaderCopyTo fail, src: \n%+v\ndst: \n%+v\n", h, h1) //nolint:govet + } +} + +func TestResponseContentTypeNoDefaultNotEmpty(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + h.SetNoDefaultContentType(true) + h.SetContentLength(5) + + headers := h.String() + + if strings.Contains(headers, "Content-Type: \r\n") { + t.Fatalf("ResponseContentTypeNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet + } +} + +func TestRequestContentTypeDefaultNotEmpty(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.SetMethod(MethodPost) + h.SetContentLength(5) + + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := h.Write(bw); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + var h1 RequestHeader + br := bufio.NewReader(w) + if err := h1.Read(br); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if string(h1.contentType) != "application/octet-stream" { + t.Fatalf("unexpected Content-Type %q. Expecting %q", h1.contentType, "application/octet-stream") + } +} + +func TestResponseDateNoDefaultNotEmpty(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + h.noDefaultDate = true + + headers := h.String() + + if strings.Contains(headers, "\r\nDate: ") { + t.Fatalf("ResponseDateNoDefaultNotEmpty fail, response: \n%+v\noutcome: \n%q\n", h, headers) //nolint:govet + } } func TestRequestHeaderConnectionClose(t *testing.T) { + t.Parallel() + var h RequestHeader - h.Set("Connection", "close") - h.Set("Host", "foobar") + h.Set(HeaderConnection, "close") + h.Set(HeaderHost, "foobar") if !h.ConnectionClose() { t.Fatalf("connection: close not set") } @@ -927,12 +1322,15 @@ if !h1.ConnectionClose() { t.Fatalf("unexpected connection: close value: %v", h1.ConnectionClose()) } - if string(h1.Peek("Connection")) != "close" { + if string(h1.Peek(HeaderConnection)) != "close" { t.Fatalf("unexpected connection value: %q. Expecting %q", h.Peek("Connection"), "close") } + } func TestRequestHeaderSetCookie(t *testing.T) { + t.Parallel() + var h RequestHeader h.Set("Cookie", "foo=bar; baz=aaa") @@ -950,10 +1348,12 @@ } func TestResponseHeaderSetCookie(t *testing.T) { + t.Parallel() + var h ResponseHeader h.Set("set-cookie", "foo=bar; path=/aa/bb; domain=aaa.com") - h.Set("Set-Cookie", "aaaaa=bxx") + h.Set(HeaderSetCookie, "aaaaa=bxx") var c Cookie c.SetKey("foo") @@ -980,12 +1380,14 @@ } func TestResponseHeaderVisitAll(t *testing.T) { + t.Parallel() + var h ResponseHeader r := bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 123\r\nSet-Cookie: aa=bb; path=/foo/bar\r\nSet-Cookie: ccc\r\n\r\n") br := bufio.NewReader(r) if err := h.Read(br); err != nil { - t.Fatalf("Unepxected error: %s", err) + t.Fatalf("Unexpected error: %s", err) } if h.Len() != 4 { @@ -998,17 +1400,17 @@ k := string(key) v := string(value) switch k { - case "Content-Length": + case HeaderContentLength: if v != string(h.Peek(k)) { t.Fatalf("unexpected content-length: %q. Expecting %q", v, h.Peek(k)) } contentLengthCount++ - case "Content-Type": + case HeaderContentType: if v != string(h.Peek(k)) { t.Fatalf("Unexpected content-type: %q. 
Expected %q", v, h.Peek(k)) } contentTypeCount++ - case "Set-Cookie": + case HeaderSetCookie: if cookieCount == 0 && v != "aa=bb; path=/foo/bar" { t.Fatalf("unexpected cookie header: %q. Expected %q", v, "aa=bb; path=/foo/bar") } @@ -1032,6 +1434,8 @@ } func TestRequestHeaderVisitAll(t *testing.T) { + t.Parallel() + var h RequestHeader r := bytes.NewBufferString("GET / HTTP/1.1\r\nHost: aa.com\r\nXX: YYY\r\nXX: ZZ\r\nCookie: a=b; c=d\r\n\r\n") @@ -1050,7 +1454,7 @@ k := string(key) v := string(value) switch k { - case "Host": + case HeaderHost: if v != string(h.Peek(k)) { t.Fatalf("Unexpected host value %q. Expected %q", v, h.Peek(k)) } @@ -1063,17 +1467,17 @@ t.Fatalf("Unexpected value %q. Expected %q", v, "ZZ") } xxCount++ - case "Cookie": + case HeaderCookie: if v != "a=b; c=d" { t.Fatalf("Unexpected cookie %q. Expected %q", v, "a=b; c=d") } cookieCount++ default: - t.Fatalf("Unepxected header %q=%q", k, v) + t.Fatalf("Unexpected header %q=%q", k, v) } }) if hostCount != 1 { - t.Fatalf("Unepxected number of host headers detected %d. Expected 1", hostCount) + t.Fatalf("Unexpected number of host headers detected %d. Expected 1", hostCount) } if xxCount != 2 { t.Fatalf("Unexpected number of xx headers detected %d. Expected 2", xxCount) @@ -1083,7 +1487,52 @@ } } +func TestResponseHeaderVisitAllInOrder(t *testing.T) { + t.Parallel() + + var h RequestHeader + + r := bytes.NewBufferString("GET / HTTP/1.1\r\nContent-Type: aa\r\nCookie: a=b\r\nHost: example.com\r\nUser-Agent: xxx\r\n\r\n") + br := bufio.NewReader(r) + if err := h.Read(br); err != nil { + t.Fatalf("Unexpected error: %s", err) + } + + if h.Len() != 4 { + t.Fatalf("Unexpected number of headers: %d. Expected 4", h.Len()) + } + + order := []string{ + HeaderContentType, + HeaderCookie, + HeaderHost, + HeaderUserAgent, + } + values := []string{ + "aa", + "a=b", + "example.com", + "xxx", + } + + h.VisitAllInOrder(func(key, value []byte) { + if len(order) == 0 { + t.Fatalf("no more headers expected, got %q", key) + } + if order[0] != string(key) { + t.Fatalf("expected header %q got %q", order[0], key) + } + if values[0] != string(value) { + t.Fatalf("expected header value %q got %q", values[0], value) + } + order = order[1:] + values = values[1:] + }) +} + func TestResponseHeaderCookie(t *testing.T) { + t.Parallel() + var h ResponseHeader var c Cookie @@ -1123,7 +1572,9 @@ h.VisitAllCookie(func(key, value []byte) { var cc Cookie - cc.ParseBytes(value) + if err := cc.ParseBytes(value); err != nil { + t.Fatal(err) + } if !bytes.Equal(key, cc.Key()) { t.Fatalf("Unexpected cookie key %q. Expected %q", key, cc.Key()) } @@ -1211,9 +1662,11 @@ } func TestRequestHeaderCookie(t *testing.T) { + t.Parallel() + var h RequestHeader h.SetRequestURI("/foobar") - h.Set("Host", "foobar.com") + h.Set(HeaderHost, "foobar.com") h.SetCookie("foo", "bar") h.SetCookie("привет", "мир") @@ -1264,12 +1717,102 @@ } } +func TestResponseHeaderCookieIssue4(t *testing.T) { + t.Parallel() + + var h ResponseHeader + + c := AcquireCookie() + c.SetKey("foo") + c.SetValue("bar") + h.SetCookie(c) + + if string(h.Peek(HeaderSetCookie)) != "foo=bar" { + t.Fatalf("Unexpected Set-Cookie header %q. 
Expected %q", h.Peek(HeaderSetCookie), "foo=bar") + } + cookieSeen := false + h.VisitAll(func(key, value []byte) { + switch string(key) { + case HeaderSetCookie: + cookieSeen = true + } + }) + if !cookieSeen { + t.Fatalf("Set-Cookie not present in VisitAll") + } + + c = AcquireCookie() + c.SetKey("foo") + h.Cookie(c) + if string(c.Value()) != "bar" { + t.Fatalf("Unexpected cookie value %q. Exepcted %q", c.Value(), "bar") + } + + if string(h.Peek(HeaderSetCookie)) != "foo=bar" { + t.Fatalf("Unexpected Set-Cookie header %q. Expected %q", h.Peek(HeaderSetCookie), "foo=bar") + } + cookieSeen = false + h.VisitAll(func(key, value []byte) { + switch string(key) { + case HeaderSetCookie: + cookieSeen = true + } + }) + if !cookieSeen { + t.Fatalf("Set-Cookie not present in VisitAll") + } +} + +func TestRequestHeaderCookieIssue313(t *testing.T) { + t.Parallel() + + var h RequestHeader + h.SetRequestURI("/") + h.Set(HeaderHost, "foobar.com") + + h.SetCookie("foo", "bar") + + if string(h.Peek(HeaderCookie)) != "foo=bar" { + t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar") + } + cookieSeen := false + h.VisitAll(func(key, value []byte) { + switch string(key) { + case HeaderCookie: + cookieSeen = true + } + }) + if !cookieSeen { + t.Fatalf("Cookie not present in VisitAll") + } + + if string(h.Cookie("foo")) != "bar" { + t.Fatalf("Unexpected cookie value %q. Exepcted %q", h.Cookie("foo"), "bar") + } + + if string(h.Peek(HeaderCookie)) != "foo=bar" { + t.Fatalf("Unexpected Cookie header %q. Expected %q", h.Peek(HeaderCookie), "foo=bar") + } + cookieSeen = false + h.VisitAll(func(key, value []byte) { + switch string(key) { + case HeaderCookie: + cookieSeen = true + } + }) + if !cookieSeen { + t.Fatalf("Cookie not present in VisitAll") + } +} + func TestRequestHeaderMethod(t *testing.T) { + t.Parallel() + // common http methods - testRequestHeaderMethod(t, "GET") - testRequestHeaderMethod(t, "POST") - testRequestHeaderMethod(t, "HEAD") - testRequestHeaderMethod(t, "DELETE") + testRequestHeaderMethod(t, MethodGet) + testRequestHeaderMethod(t, MethodPost) + testRequestHeaderMethod(t, MethodHead) + testRequestHeaderMethod(t, MethodDelete) // non-http methods testRequestHeaderMethod(t, "foobar") @@ -1297,9 +1840,11 @@ } func TestRequestHeaderSetGet(t *testing.T) { + t.Parallel() + h := &RequestHeader{} h.SetRequestURI("/aa/bbb") - h.SetMethod("POST") + h.SetMethod(MethodPost) h.Set("foo", "bar") h.Set("host", "12345") h.Set("content-type", "aaa/bbb") @@ -1311,13 +1856,13 @@ h.Set("connection", "close") expectRequestHeaderGet(t, h, "Foo", "bar") - expectRequestHeaderGet(t, h, "Host", "12345") - expectRequestHeaderGet(t, h, "Content-Type", "aaa/bbb") - expectRequestHeaderGet(t, h, "Content-Length", "1234") + expectRequestHeaderGet(t, h, HeaderHost, "12345") + expectRequestHeaderGet(t, h, HeaderContentType, "aaa/bbb") + expectRequestHeaderGet(t, h, HeaderContentLength, "1234") expectRequestHeaderGet(t, h, "USER-AGent", "aaabbb") - expectRequestHeaderGet(t, h, "Referer", "axcv") + expectRequestHeaderGet(t, h, HeaderReferer, "axcv") expectRequestHeaderGet(t, h, "baz", "xxxxx") - expectRequestHeaderGet(t, h, "Transfer-Encoding", "") + expectRequestHeaderGet(t, h, HeaderTransferEncoding, "") expectRequestHeaderGet(t, h, "connecTION", "close") if !h.ConnectionClose() { t.Fatalf("unset connection: close") @@ -1349,35 +1894,37 @@ expectRequestHeaderGet(t, &h1, "Foo", "bar") expectRequestHeaderGet(t, &h1, "HOST", "12345") - expectRequestHeaderGet(t, &h1, "Content-Type", "aaa/bbb") - 
expectRequestHeaderGet(t, &h1, "Content-Length", "1234") + expectRequestHeaderGet(t, &h1, HeaderContentType, "aaa/bbb") + expectRequestHeaderGet(t, &h1, HeaderContentLength, "1234") expectRequestHeaderGet(t, &h1, "USER-AGent", "aaabbb") - expectRequestHeaderGet(t, &h1, "Referer", "axcv") + expectRequestHeaderGet(t, &h1, HeaderReferer, "axcv") expectRequestHeaderGet(t, &h1, "baz", "xxxxx") - expectRequestHeaderGet(t, &h1, "Transfer-Encoding", "") - expectRequestHeaderGet(t, &h1, "Connection", "close") + expectRequestHeaderGet(t, &h1, HeaderTransferEncoding, "") + expectRequestHeaderGet(t, &h1, HeaderConnection, "close") if !h1.ConnectionClose() { t.Fatalf("unset connection: close") } } func TestResponseHeaderSetGet(t *testing.T) { + t.Parallel() + h := &ResponseHeader{} h.Set("foo", "bar") h.Set("content-type", "aaa/bbb") h.Set("connection", "close") h.Set("content-length", "1234") - h.Set("Server", "aaaa") + h.Set(HeaderServer, "aaaa") h.Set("baz", "xxxxx") - h.Set("Transfer-Encoding", "chunked") + h.Set(HeaderTransferEncoding, "chunked") expectResponseHeaderGet(t, h, "Foo", "bar") - expectResponseHeaderGet(t, h, "Content-Type", "aaa/bbb") - expectResponseHeaderGet(t, h, "Connection", "close") - expectResponseHeaderGet(t, h, "Content-Length", "1234") + expectResponseHeaderGet(t, h, HeaderContentType, "aaa/bbb") + expectResponseHeaderGet(t, h, HeaderConnection, "close") + expectResponseHeaderGet(t, h, HeaderContentLength, "1234") expectResponseHeaderGet(t, h, "seRVer", "aaaa") expectResponseHeaderGet(t, h, "baz", "xxxxx") - expectResponseHeaderGet(t, h, "Transfer-Encoding", "") + expectResponseHeaderGet(t, h, HeaderTransferEncoding, "") if h.ContentLength() != 1234 { t.Fatalf("Unexpected content-length %d. Expected %d", h.ContentLength(), 1234) @@ -1410,8 +1957,8 @@ } expectResponseHeaderGet(t, &h1, "Foo", "bar") - expectResponseHeaderGet(t, &h1, "Content-Type", "aaa/bbb") - expectResponseHeaderGet(t, &h1, "Connection", "close") + expectResponseHeaderGet(t, &h1, HeaderContentType, "aaa/bbb") + expectResponseHeaderGet(t, &h1, HeaderConnection, "close") expectResponseHeaderGet(t, &h1, "seRVer", "aaaa") expectResponseHeaderGet(t, &h1, "baz", "xxxxx") } @@ -1429,6 +1976,8 @@ } func TestResponseHeaderConnectionClose(t *testing.T) { + t.Parallel() + testResponseHeaderConnectionClose(t, true) testResponseHeaderConnectionClose(t, false) } @@ -1462,6 +2011,8 @@ } func TestRequestHeaderTooBig(t *testing.T) { + t.Parallel() + s := "GET / HTTP/1.1\r\nHost: aaa.com\r\n" + getHeaders(10500) + "\r\n" r := bytes.NewBufferString(s) br := bufio.NewReaderSize(r, 4096) @@ -1473,6 +2024,8 @@ } func TestResponseHeaderTooBig(t *testing.T) { + t.Parallel() + s := "HTTP/1.1 200 OK\r\nContent-Type: sss\r\nContent-Length: 0\r\n" + getHeaders(100500) + "\r\n" r := bytes.NewBufferString(s) br := bufio.NewReaderSize(r, 4096) @@ -1505,6 +2058,8 @@ } func TestRequestHeaderBufioPeek(t *testing.T) { + t.Parallel() + r := &bufioPeekReader{ s: "GET / HTTP/1.1\r\nHost: foobar.com\r\n" + getHeaders(10) + "\r\naaaa", } @@ -1513,11 +2068,13 @@ if err := h.Read(br); err != nil { t.Fatalf("Unexpected error when reading request: %s", err) } - verifyRequestHeader(t, h, 0, "/", "foobar.com", "", "") + verifyRequestHeader(t, h, -2, "/", "foobar.com", "", "") verifyTrailer(t, br, "aaaa") } func TestResponseHeaderBufioPeek(t *testing.T) { + t.Parallel() + r := &bufioPeekReader{ s: "HTTP/1.1 200 OK\r\nContent-Length: 10\r\nContent-Type: aaa\r\n" + getHeaders(10) + "\r\n0123456789", } @@ -1539,6 +2096,8 @@ } func 
TestResponseHeaderReadSuccess(t *testing.T) { + t.Parallel() + h := &ResponseHeader{} // straight order of content-length and content-type @@ -1664,6 +2223,12 @@ testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa", 400, 123, string(defaultContentType), "foiaaa") + // no content-type and no default + h.SetNoDefaultContentType(true) + testResponseHeaderReadSuccess(t, h, "HTTP/1.1 400 OK\r\nContent-Length: 123\r\n\r\nfoiaaa", + 400, 123, "", "foiaaa") + h.SetNoDefaultContentType(false) + // no headers testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\n\r\naaaabbb", 200, -2, string(defaultContentType), "aaaabbb") @@ -1693,25 +2258,27 @@ } func TestRequestHeaderReadSuccess(t *testing.T) { + t.Parallel() + h := &RequestHeader{} // simple headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: google.com\r\n\r\n", - 0, "/foo/bar", "google.com", "", "", "") + -2, "/foo/bar", "google.com", "", "", "") if h.ConnectionClose() { t.Fatalf("unexpected connection: close header") } // simple headers with body testRequestHeaderReadSuccess(t, h, "GET /a/bar HTTP/1.1\r\nHost: gole.com\r\nconneCTION: close\r\n\r\nfoobar", - 0, "/a/bar", "gole.com", "", "", "foobar") + -2, "/a/bar", "gole.com", "", "", "foobar") if !h.ConnectionClose() { t.Fatalf("connection: close unset") } // ancient http protocol testRequestHeaderReadSuccess(t, h, "GET /bar HTTP/1.0\r\nHost: gole\r\n\r\npppp", - 0, "/bar", "gole", "", "", "pppp") + -2, "/bar", "gole", "", "", "pppp") if h.IsHTTP11() { t.Fatalf("ancient http protocol cannot be http/1.1") } @@ -1721,7 +2288,7 @@ // ancient http protocol with 'Connection: keep-alive' header testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.0\r\nHost: bb\r\nConnection: keep-alive\r\n\r\nxxx", - 0, "/aa", "bb", "", "", "xxx") + -2, "/aa", "bb", "", "", "xxx") if h.IsHTTP11() { t.Fatalf("ancient http protocol cannot be http/1.1") } @@ -1731,7 +2298,7 @@ // complex headers with body testRequestHeaderReadSuccess(t, h, "GET /aabar HTTP/1.1\r\nAAA: bbb\r\nHost: ole.com\r\nAA: bb\r\n\r\nzzz", - 0, "/aabar", "ole.com", "", "", "zzz") + -2, "/aabar", "ole.com", "", "", "zzz") if !h.IsHTTP11() { t.Fatalf("expecting http/1.1 protocol") } @@ -1741,7 +2308,7 @@ // lf instead of crlf testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\nHost: google.com\n\n", - 0, "/foo/bar", "google.com", "", "", "") + -2, "/foo/bar", "google.com", "", "", "") // post method testRequestHeaderReadSuccess(t, h, "POST /aaa?bbb HTTP/1.1\r\nHost: foobar.com\r\nContent-Length: 1235\r\nContent-Type: aaa\r\n\r\nabcdef", @@ -1749,11 +2316,11 @@ // zero-length headers with mixed crlf and lf testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost: aaa\r\nZero: \n: Zero-Value\n\r\nxccv", - 0, "/a", "aaa", "", "", "xccv") + -2, "/a", "aaa", "", "", "xccv") // no space after colon testRequestHeaderReadSuccess(t, h, "GET /a HTTP/1.1\nHost:aaaxd\n\nsdfds", - 0, "/a", "aaaxd", "", "", "sdfds") + -2, "/a", "aaaxd", "", "", "sdfds") // get with zero content-length testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 0\n\n", @@ -1761,19 +2328,19 @@ // get with non-zero content-length testRequestHeaderReadSuccess(t, h, "GET /xxx HTTP/1.1\nHost: aaa.com\nContent-Length: 123\n\n", - 0, "/xxx", "aaa.com", "", "", "") + 123, "/xxx", "aaa.com", "", "", "") // invalid case testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\nhoST: bbb.com\n\naas", - 0, "/aaa", "bbb.com", "", "", "aas") + -2, "/aaa", "bbb.com", "", "", "aas") // referer 
testRequestHeaderReadSuccess(t, h, "GET /asdf HTTP/1.1\nHost: aaa.com\nReferer: bb.com\n\naaa", - 0, "/asdf", "aaa.com", "bb.com", "", "aaa") + -2, "/asdf", "aaa.com", "bb.com", "", "aaa") // duplicate host testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aaaaaa.com\r\nHost: bb.com\r\n\r\n", - 0, "/aa", "bb.com", "", "", "") + -2, "/aa", "bb.com", "", "", "") // post with duplicate content-type testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n", @@ -1785,43 +2352,39 @@ // non-post with content-type testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n", - 0, "/aaa", "bbb.com", "", "aaab", "") + -2, "/aaa", "bbb.com", "", "aaab", "") // non-post with content-length testRequestHeaderReadSuccess(t, h, "HEAD / HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 123\r\n\r\n", - 0, "/", "aaa.com", "", "", "") + 123, "/", "aaa.com", "", "", "") // non-post with content-type and content-length testRequestHeaderReadSuccess(t, h, "GET /aa HTTP/1.1\r\nHost: aa.com\r\nContent-Type: abd/test\r\nContent-Length: 123\r\n\r\n", - 0, "/aa", "aa.com", "", "abd/test", "") + 123, "/aa", "aa.com", "", "abd/test", "") // request uri with hostname testRequestHeaderReadSuccess(t, h, "GET http://gooGle.com/foO/%20bar?xxx#aaa HTTP/1.1\r\nHost: aa.cOM\r\n\r\ntrail", - 0, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", "trail") + -2, "http://gooGle.com/foO/%20bar?xxx#aaa", "aa.cOM", "", "", "trail") // no protocol in the first line testRequestHeaderReadSuccess(t, h, "GET /foo/bar\r\nHost: google.com\r\n\r\nisdD", - 0, "/foo/bar", "google.com", "", "", "isdD") + -2, "/foo/bar", "google.com", "", "", "isdD") // blank lines before the first line testRequestHeaderReadSuccess(t, h, "\r\n\n\r\nGET /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nsss", - 0, "/aaa", "aaa.com", "", "", "sss") + -2, "/aaa", "aaa.com", "", "", "sss") // request uri with spaces testRequestHeaderReadSuccess(t, h, "GET /foo/ bar baz HTTP/1.1\r\nHost: aa.com\r\n\r\nxxx", - 0, "/foo/ bar baz", "aa.com", "", "", "xxx") + -2, "/foo/ bar baz", "aa.com", "", "", "xxx") // no host testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nFOObar: assdfd\r\n\r\naaa", - 0, "/foo/bar", "", "", "", "aaa") + -2, "/foo/bar", "", "", "", "aaa") // no host, no headers testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\n\r\nfoobar", - 0, "/foo/bar", "", "", "", "foobar") - - // post with invalid content-length - testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty", - -2, "/a", "bb", "", "aa", "qwerty") + -2, "/foo/bar", "", "", "", "foobar") // post without content-length and content-type testRequestHeaderReadSuccess(t, h, "POST /aaa HTTP/1.1\r\nHost: aaa.com\r\n\r\nzxc", @@ -1845,6 +2408,8 @@ } func TestResponseHeaderReadError(t *testing.T) { + t.Parallel() + h := &ResponseHeader{} // incorrect first line @@ -1867,7 +2432,35 @@ testResponseHeaderReadError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n") } +func TestResponseHeaderReadErrorSecureLog(t *testing.T) { + t.Parallel() + + h := &ResponseHeader{ + secureErrorLogMessage: true, + } + + // incorrect first line + testResponseHeaderReadSecuredError(t, h, "fo") + testResponseHeaderReadSecuredError(t, h, "foobarbaz") + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1") + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 ") + testResponseHeaderReadSecuredError(t, h, 
"HTTP/1.1 s") + + // non-numeric status code + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 123foobar OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 foobar344 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n\r\n") + + // no headers + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 200 OK\r\n") + + // no trailing crlf + testResponseHeaderReadSecuredError(t, h, "HTTP/1.1 200 OK\r\nContent-Length: 123\r\nContent-Type: text/html\r\n") +} + func TestRequestHeaderReadError(t *testing.T) { + t.Parallel() + h := &RequestHeader{} // incorrect first line @@ -1878,6 +2471,28 @@ // missing RequestURI testRequestHeaderReadError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n") + + // post with invalid content-length + testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty") +} + +func TestRequestHeaderReadSecuredError(t *testing.T) { + t.Parallel() + + h := &RequestHeader{ + secureErrorLogMessage: true, + } + + // incorrect first line + testRequestHeaderReadSecuredError(t, h, "fo") + testRequestHeaderReadSecuredError(t, h, "GET ") + testRequestHeaderReadSecuredError(t, h, "GET / HTTP/1.1\r") + + // missing RequestURI + testRequestHeaderReadSecuredError(t, h, "GET HTTP/1.1\r\nHost: google.com\r\n\r\n") + + // post with invalid content-length + testRequestHeaderReadSecuredError(t, h, "POST /a HTTP/1.1\r\nHost: bb\r\nContent-Type: aa\r\nContent-Length: dff\r\n\r\nqwerty") } func testResponseHeaderReadError(t *testing.T, h *ResponseHeader, headers string) { @@ -1887,7 +2502,21 @@ if err == nil { t.Fatalf("Expecting error when reading response header %q", headers) } + // make sure response header works after error + testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss", + 200, 12345, "foo/bar", "sss") +} +func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers string) { + r := bytes.NewBufferString(headers) + br := bufio.NewReader(r) + err := h.Read(br) + if err == nil { + t.Fatalf("Expecting error when reading response header %q", headers) + } + if strings.Contains(err.Error(), headers) { + t.Fatalf("Not expecting header content in err %q", err) + } // make sure response header works after error testResponseHeaderReadSuccess(t, h, "HTTP/1.1 200 OK\r\nContent-Type: foo/bar\r\nContent-Length: 12345\r\n\r\nsss", 200, 12345, "foo/bar", "sss") @@ -1903,7 +2532,22 @@ // make sure request header works after error testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx", - 0, "/foo/bar", "aaaa", "", "", "xxx") + -2, "/foo/bar", "aaaa", "", "", "xxx") +} + +func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers string) { + r := bytes.NewBufferString(headers) + br := bufio.NewReader(r) + err := h.Read(br) + if err == nil { + t.Fatalf("Expecting error when reading request header %q", headers) + } + if strings.Contains(err.Error(), headers) { + t.Fatalf("Not expecting header content in err %q", err) + } + // make sure request header works after error + testRequestHeaderReadSuccess(t, h, "GET /foo/bar HTTP/1.1\r\nHost: aaaa\r\n\r\nxxx", + -2, "/foo/bar", "aaaa", "", "", "xxx") } func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int, @@ 
-1937,8 +2581,14 @@ if h.ContentLength() != expectedContentLength { t.Fatalf("Unexpected content length %d. Expected %d", h.ContentLength(), expectedContentLength) } - if string(h.Peek("Content-Type")) != expectedContentType { - t.Fatalf("Unexpected content type %q. Expected %q", h.Peek("Content-Type"), expectedContentType) + if string(h.Peek(HeaderContentType)) != expectedContentType { + t.Fatalf("Unexpected content type %q. Expected %q", h.Peek(HeaderContentType), expectedContentType) + } +} + +func verifyResponseHeaderConnection(t *testing.T, h *ResponseHeader, expectConnection string) { + if string(h.Peek(HeaderConnection)) != expectConnection { + t.Fatalf("Unexpected Connection %q. Expected %q", h.Peek(HeaderConnection), expectConnection) } } @@ -1950,14 +2600,14 @@ if string(h.RequestURI()) != expectedRequestURI { t.Fatalf("Unexpected RequestURI %q. Expected %q", h.RequestURI(), expectedRequestURI) } - if string(h.Peek("Host")) != expectedHost { - t.Fatalf("Unexpected host %q. Expected %q", h.Peek("Host"), expectedHost) + if string(h.Peek(HeaderHost)) != expectedHost { + t.Fatalf("Unexpected host %q. Expected %q", h.Peek(HeaderHost), expectedHost) } - if string(h.Peek("Referer")) != expectedReferer { - t.Fatalf("Unexpected referer %q. Expected %q", h.Peek("Referer"), expectedReferer) + if string(h.Peek(HeaderReferer)) != expectedReferer { + t.Fatalf("Unexpected referer %q. Expected %q", h.Peek(HeaderReferer), expectedReferer) } - if string(h.Peek("Content-Type")) != expectedContentType { - t.Fatalf("Unexpected content-type %q. Expected %q", h.Peek("Content-Type"), expectedContentType) + if string(h.Peek(HeaderContentType)) != expectedContentType { + t.Fatalf("Unexpected content-type %q. Expected %q", h.Peek(HeaderContentType), expectedContentType) } } diff -Nru golang-github-valyala-fasthttp-20160617/header_timing_test.go golang-github-valyala-fasthttp-1.31.0/header_timing_test.go --- golang-github-valyala-fasthttp-20160617/header_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/header_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -4,7 +4,10 @@ "bufio" "bytes" "io" + "strconv" "testing" + + "github.com/valyala/bytebufferpool" ) var strFoobar = []byte("foobar.com") @@ -65,7 +68,7 @@ h.SetHost("foobar.com") h.SetUserAgent("aaa.bbb") h.SetReferer("http://google.com/aaa/bbb") - var w ByteBuffer + var w bytebufferpool.ByteBuffer for pb.Next() { if _, err := h.WriteTo(&w); err != nil { b.Fatalf("unexpected error when writing header: %s", err) @@ -83,7 +86,7 @@ h.SetContentLength(1256) h.SetServer("aaa 1/2.3") h.Set("Test", "1.2.3") - var w ByteBuffer + var w bytebufferpool.ByteBuffer for pb.Next() { if _, err := h.WriteTo(&w); err != nil { b.Fatalf("unexpected error when writing header: %s", err) @@ -144,3 +147,43 @@ } }) } + +func BenchmarkRemoveNewLines(b *testing.B) { + type testcase struct { + value string + expectedValue string + } + + var testcases = []testcase{ + {value: "MaliciousValue", expectedValue: "MaliciousValue"}, + {value: "MaliciousValue\r\n", expectedValue: "MaliciousValue "}, + {value: "Malicious\nValue", expectedValue: "Malicious Value"}, + {value: "Malicious\rValue", expectedValue: "Malicious Value"}, + } + + for i, tcase := range testcases { + caseName := strconv.FormatInt(int64(i), 10) + b.Run(caseName, func(subB *testing.B) { + subB.ReportAllocs() + var h RequestHeader + for i := 0; i < subB.N; i++ { + h.Set("Test", tcase.value) + } + subB.StopTimer() + actualValue := string(h.Peek("Test")) + + if actualValue != 
tcase.expectedValue { + subB.Errorf("unexpected value, got: %+v", actualValue) + } + }) + } +} + +func BenchmarkRequestHeaderIsGet(b *testing.B) { + req := &RequestHeader{method: []byte(MethodGet)} + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req.IsGet() + } + }) +} diff -Nru golang-github-valyala-fasthttp-20160617/http.go golang-github-valyala-fasthttp-1.31.0/http.go --- golang-github-valyala-fasthttp-20160617/http.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/http.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,12 +3,18 @@ import ( "bufio" "bytes" + "compress/gzip" + "encoding/base64" "errors" "fmt" "io" "mime/multipart" + "net" "os" "sync" + "time" + + "github.com/valyala/bytebufferpool" ) // Request represents HTTP request. @@ -18,7 +24,7 @@ // // Request instance MUST NOT be used from concurrently running goroutines. type Request struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Request header // @@ -30,14 +36,26 @@ bodyStream io.Reader w requestBodyWriter - body *ByteBuffer + body *bytebufferpool.ByteBuffer + bodyRaw []byte multipartForm *multipart.Form multipartFormBoundary string + secureErrorLogMessage bool // Group bool members in order to reduce Request object size. parsedURI bool parsedPostArgs bool + + keepBodyBuffer bool + + // Used by Server to indicate the request was received on a HTTPS endpoint. + // Client/HostClient shouldn't use this field but should depend on the uri.scheme instead. + isTLS bool + + // Request timeout. Usually set by DoDeadline or DoTimeout + // if <= 0, means not set + timeout time.Duration } // Response represents HTTP response. @@ -47,16 +65,21 @@ // // Response instance MUST NOT be used from concurrently running goroutines. type Response struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Response header // // Copying Header by value is forbidden. Use pointer to Header instead. Header ResponseHeader + // Flush headers as soon as possible without waiting for first body bytes. + // Relevant for bodyStream only. + ImmediateHeaderFlush bool + bodyStream io.Reader w responseBodyWriter - body *ByteBuffer + body *bytebufferpool.ByteBuffer + bodyRaw []byte // Response.Read() skips reading body if set to true. // Use it for reading HEAD responses. @@ -65,9 +88,13 @@ // Use it for writing HEAD responses. SkipBody bool - // This is a hackish field for client implementation, which allows - // avoiding body copying. - keepBodyBuffer bool + keepBodyBuffer bool + secureErrorLogMessage bool + + // Remote TCPAddr from concurrently net.Conn + raddr net.Addr + // Local TCPAddr from concurrently net.Conn + laddr net.Addr } // SetHost sets host for the request. @@ -274,13 +301,34 @@ return len(p), nil } +func (resp *Response) parseNetConn(conn net.Conn) { + resp.raddr = conn.RemoteAddr() + resp.laddr = conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. The Addr returned is shared +// by all invocations of RemoteAddr, so do not modify it. +func (resp *Response) RemoteAddr() net.Addr { + return resp.raddr +} + +// LocalAddr returns the local network address. The Addr returned is shared +// by all invocations of LocalAddr, so do not modify it. +func (resp *Response) LocalAddr() net.Addr { + return resp.laddr +} + // Body returns response body. +// +// The returned value is valid until the response is released, +// either though ReleaseResponse or your request handler returning. +// Do not store references to returned value. Make copies instead. 
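As the new Body documentation stresses, the returned slice is only valid until the response is released, and RemoteAddr/LocalAddr expose the connection addresses recorded for the response. A short client-side sketch under those assumptions (the addresses are filled in by the client transport; imports assumed: log and github.com/valyala/fasthttp):

func fetch(c *fasthttp.Client, url string) ([]byte, error) {
	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI(url)
	if err := c.Do(req, resp); err != nil {
		return nil, err
	}

	// Copy the body out: resp.Body() becomes invalid after ReleaseResponse.
	body := append([]byte(nil), resp.Body()...)

	// Shared net.Addr values; do not modify them.
	log.Printf("response from %v via local %v", resp.RemoteAddr(), resp.LocalAddr())
	return body, nil
}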
func (resp *Response) Body() []byte { if resp.bodyStream != nil { bodyBuf := resp.bodyBuffer() bodyBuf.Reset() _, err := copyZeroAlloc(bodyBuf, resp.bodyStream) - resp.closeBodyStream() + resp.closeBodyStream() //nolint:errcheck if err != nil { bodyBuf.SetString(err.Error()) } @@ -289,6 +337,9 @@ } func (resp *Response) bodyBytes() []byte { + if resp.bodyRaw != nil { + return resp.bodyRaw + } if resp.body == nil { return nil } @@ -296,29 +347,43 @@ } func (req *Request) bodyBytes() []byte { + if req.bodyRaw != nil { + return req.bodyRaw + } + if req.bodyStream != nil { + bodyBuf := req.bodyBuffer() + bodyBuf.Reset() + _, err := copyZeroAlloc(bodyBuf, req.bodyStream) + req.closeBodyStream() //nolint:errcheck + if err != nil { + bodyBuf.SetString(err.Error()) + } + } if req.body == nil { return nil } return req.body.B } -func (resp *Response) bodyBuffer() *ByteBuffer { +func (resp *Response) bodyBuffer() *bytebufferpool.ByteBuffer { if resp.body == nil { - resp.body = responseBodyPool.Acquire() + resp.body = responseBodyPool.Get() } + resp.bodyRaw = nil return resp.body } -func (req *Request) bodyBuffer() *ByteBuffer { +func (req *Request) bodyBuffer() *bytebufferpool.ByteBuffer { if req.body == nil { - req.body = requestBodyPool.Acquire() + req.body = requestBodyPool.Get() } + req.bodyRaw = nil return req.body } var ( - requestBodyPool byteBufferPool - responseBodyPool byteBufferPool + responseBodyPool bytebufferpool.Pool + requestBodyPool bytebufferpool.Pool ) // BodyGunzip returns un-gzipped body data. @@ -340,7 +405,7 @@ } func gunzipData(p []byte) ([]byte, error) { - var bb ByteBuffer + var bb bytebufferpool.ByteBuffer _, err := WriteGunzip(&bb, p) if err != nil { return nil, err @@ -348,6 +413,33 @@ return bb.B, nil } +// BodyUnbrotli returns un-brotlied body data. +// +// This method may be used if the request header contains +// 'Content-Encoding: br' for reading un-brotlied body. +// Use Body for reading brotlied request body. +func (req *Request) BodyUnbrotli() ([]byte, error) { + return unBrotliData(req.Body()) +} + +// BodyUnbrotli returns un-brotlied body data. +// +// This method may be used if the response header contains +// 'Content-Encoding: br' for reading un-brotlied body. +// Use Body for reading brotlied response body. +func (resp *Response) BodyUnbrotli() ([]byte, error) { + return unBrotliData(resp.Body()) +} + +func unBrotliData(p []byte) ([]byte, error) { + var bb bytebufferpool.ByteBuffer + _, err := WriteUnbrotli(&bb, p) + if err != nil { + return nil, err + } + return bb.B, nil +} + // BodyInflate returns inflated body data. 
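Together with BodyGunzip and BodyInflate, the new BodyUnbrotli helper covers the three Content-Encoding values the library can decode. A small dispatch sketch (assumes the HeaderContentEncoding constant and the github.com/valyala/fasthttp import):

func decodedBody(resp *fasthttp.Response) ([]byte, error) {
	switch string(resp.Header.Peek(fasthttp.HeaderContentEncoding)) {
	case "br":
		return resp.BodyUnbrotli()
	case "gzip":
		return resp.BodyGunzip()
	case "deflate":
		return resp.BodyInflate()
	default:
		// No (or unknown) encoding: return the body as-is.
		return resp.Body(), nil
	}
}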
// // This method may be used if the response header contains @@ -366,8 +458,12 @@ return inflateData(resp.Body()) } +func (ctx *RequestCtx) RequestBodyStream() io.Reader { + return ctx.Request.bodyStream +} + func inflateData(p []byte) ([]byte, error) { - var bb ByteBuffer + var bb bytebufferpool.ByteBuffer _, err := WriteInflate(&bb, p) if err != nil { return nil, err @@ -379,7 +475,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error { if req.bodyStream != nil { _, err := copyZeroAlloc(w, req.bodyStream) - req.closeBodyStream() + req.closeBodyStream() //nolint:errcheck return err } if req.onlyMultipartForm() { @@ -393,7 +489,7 @@ func (resp *Response) BodyWriteTo(w io.Writer) error { if resp.bodyStream != nil { _, err := copyZeroAlloc(w, resp.bodyStream) - resp.closeBodyStream() + resp.closeBodyStream() //nolint:errcheck return err } _, err := w.Write(resp.bodyBytes()) @@ -404,50 +500,75 @@ // // It is safe re-using p after the function returns. func (resp *Response) AppendBody(p []byte) { - resp.AppendBodyString(b2s(p)) + resp.closeBodyStream() //nolint:errcheck + resp.bodyBuffer().Write(p) //nolint:errcheck } // AppendBodyString appends s to response body. func (resp *Response) AppendBodyString(s string) { - resp.closeBodyStream() - resp.bodyBuffer().WriteString(s) + resp.closeBodyStream() //nolint:errcheck + resp.bodyBuffer().WriteString(s) //nolint:errcheck } // SetBody sets response body. // // It is safe re-using body argument after the function returns. func (resp *Response) SetBody(body []byte) { - resp.SetBodyString(b2s(body)) + resp.closeBodyStream() //nolint:errcheck + bodyBuf := resp.bodyBuffer() + bodyBuf.Reset() + bodyBuf.Write(body) //nolint:errcheck } // SetBodyString sets response body. func (resp *Response) SetBodyString(body string) { - resp.closeBodyStream() + resp.closeBodyStream() //nolint:errcheck bodyBuf := resp.bodyBuffer() bodyBuf.Reset() - bodyBuf.WriteString(body) + bodyBuf.WriteString(body) //nolint:errcheck } // ResetBody resets response body. func (resp *Response) ResetBody() { - resp.closeBodyStream() + resp.bodyRaw = nil + resp.closeBodyStream() //nolint:errcheck if resp.body != nil { if resp.keepBodyBuffer { resp.body.Reset() } else { - responseBodyPool.Release(resp.body) + responseBodyPool.Put(resp.body) resp.body = nil } } } +// SetBodyRaw sets response body, but without copying it. +// +// From this point onward the body argument must not be changed. +func (resp *Response) SetBodyRaw(body []byte) { + resp.ResetBody() + resp.bodyRaw = body +} + +// SetBodyRaw sets response body, but without copying it. +// +// From this point onward the body argument must not be changed. +func (req *Request) SetBodyRaw(body []byte) { + req.ResetBody() + req.bodyRaw = body +} + // ReleaseBody retires the response body if it is greater than "size" bytes. // // This permits GC to reclaim the large buffer. If used, must be before // ReleaseResponse. +// +// Use this method only if you really understand how it works. +// The majority of workloads don't need this method. func (resp *Response) ReleaseBody(size int) { + resp.bodyRaw = nil if cap(resp.body.B) > size { - resp.closeBodyStream() + resp.closeBodyStream() //nolint:errcheck resp.body = nil } } @@ -456,23 +577,75 @@ // // This permits GC to reclaim the large buffer. If used, must be before // ReleaseRequest. +// +// Use this method only if you really understand how it works. +// The majority of workloads don't need this method. 
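SetBodyRaw hands the response a caller-owned slice without copying, which suits precomputed payloads as long as the slice is never mutated afterwards. A handler sketch under that contract (the cachedHealth variable is illustrative):

var cachedHealth = []byte(`{"status":"ok"}`)

func healthHandler(ctx *fasthttp.RequestCtx) {
	ctx.Response.Header.SetContentType("application/json")
	// Zero-copy: cachedHealth must not be modified from this point on.
	ctx.Response.SetBodyRaw(cachedHealth)
}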
func (req *Request) ReleaseBody(size int) { + req.bodyRaw = nil if cap(req.body.B) > size { - req.closeBodyStream() + req.closeBodyStream() //nolint:errcheck req.body = nil } } -// Body returns request body. -func (req *Request) Body() []byte { +// SwapBody swaps response body with the given body and returns +// the previous response body. +// +// It is forbidden to use the body passed to SwapBody after +// the function returns. +func (resp *Response) SwapBody(body []byte) []byte { + bb := resp.bodyBuffer() + + if resp.bodyStream != nil { + bb.Reset() + _, err := copyZeroAlloc(bb, resp.bodyStream) + resp.closeBodyStream() //nolint:errcheck + if err != nil { + bb.Reset() + bb.SetString(err.Error()) + } + } + + resp.bodyRaw = nil + + oldBody := bb.B + bb.B = body + return oldBody +} + +// SwapBody swaps request body with the given body and returns +// the previous request body. +// +// It is forbidden to use the body passed to SwapBody after +// the function returns. +func (req *Request) SwapBody(body []byte) []byte { + bb := req.bodyBuffer() + if req.bodyStream != nil { - bodyBuf := req.bodyBuffer() - bodyBuf.Reset() - _, err := copyZeroAlloc(bodyBuf, req.bodyStream) - req.closeBodyStream() + bb.Reset() + _, err := copyZeroAlloc(bb, req.bodyStream) + req.closeBodyStream() //nolint:errcheck if err != nil { - bodyBuf.SetString(err.Error()) + bb.Reset() + bb.SetString(err.Error()) } + } + + req.bodyRaw = nil + + oldBody := bb.B + bb.B = body + return oldBody +} + +// Body returns request body. +// +// The returned value is valid until the request is released, +// either though ReleaseRequest or your request handler returning. +// Do not store references to returned value. Make copies instead. +func (req *Request) Body() []byte { + if req.bodyRaw != nil { + return req.bodyRaw } else if req.onlyMultipartForm() { body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary) if err != nil { @@ -487,44 +660,58 @@ // // It is safe re-using p after the function returns. func (req *Request) AppendBody(p []byte) { - req.AppendBodyString(b2s(p)) + req.RemoveMultipartFormFiles() + req.closeBodyStream() //nolint:errcheck + req.bodyBuffer().Write(p) //nolint:errcheck } // AppendBodyString appends s to request body. func (req *Request) AppendBodyString(s string) { req.RemoveMultipartFormFiles() - req.closeBodyStream() - req.bodyBuffer().WriteString(s) + req.closeBodyStream() //nolint:errcheck + req.bodyBuffer().WriteString(s) //nolint:errcheck } // SetBody sets request body. // // It is safe re-using body argument after the function returns. func (req *Request) SetBody(body []byte) { - req.SetBodyString(b2s(body)) + req.RemoveMultipartFormFiles() + req.closeBodyStream() //nolint:errcheck + req.bodyBuffer().Set(body) } // SetBodyString sets request body. func (req *Request) SetBodyString(body string) { req.RemoveMultipartFormFiles() - req.closeBodyStream() + req.closeBodyStream() //nolint:errcheck req.bodyBuffer().SetString(body) } // ResetBody resets request body. func (req *Request) ResetBody() { + req.bodyRaw = nil req.RemoveMultipartFormFiles() - req.closeBodyStream() + req.closeBodyStream() //nolint:errcheck if req.body != nil { - requestBodyPool.Release(req.body) - req.body = nil + if req.keepBodyBuffer { + req.body.Reset() + } else { + requestBodyPool.Put(req.body) + req.body = nil + } } } // CopyTo copies req contents to dst except of body stream. 
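SwapBody trades the current body buffer for a caller-supplied one, which lets a client loop reuse a single byte slice instead of copying Body() on every response. A sketch of that reuse pattern (it mirrors the SwapBody tests later in http_test.go):

func sumBodySizes(resps []*fasthttp.Response) int {
	var scratch []byte
	total := 0
	for _, resp := range resps {
		// Take ownership of the body bytes; the response keeps the slice we pass in.
		scratch = resp.SwapBody(scratch[:0])
		total += len(scratch)
	}
	return total
}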
func (req *Request) CopyTo(dst *Request) { req.copyToSkipBody(dst) - if req.body != nil { + if req.bodyRaw != nil { + dst.bodyRaw = req.bodyRaw + if dst.body != nil { + dst.body.Reset() + } + } else if req.body != nil { dst.bodyBuffer().Set(req.body.B) } else if dst.body != nil { dst.body.Reset() @@ -540,6 +727,7 @@ req.postArgs.CopyTo(&dst.postArgs) dst.parsedPostArgs = req.parsedPostArgs + dst.isTLS = req.isTLS // do not copy multipartForm - it will be automatically // re-created on the first call to MultipartForm. @@ -548,7 +736,12 @@ // CopyTo copies resp contents to dst except of body stream. func (resp *Response) CopyTo(dst *Response) { resp.copyToSkipBody(dst) - if resp.body != nil { + if resp.bodyRaw != nil { + dst.bodyRaw = resp.bodyRaw + if dst.body != nil { + dst.body.Reset() + } + } else if resp.body != nil { dst.bodyBuffer().Set(resp.body.B) } else if dst.body != nil { dst.body.Reset() @@ -559,31 +752,35 @@ dst.Reset() resp.Header.CopyTo(&dst.Header) dst.SkipBody = resp.SkipBody + dst.raddr = resp.raddr + dst.laddr = resp.laddr } func swapRequestBody(a, b *Request) { a.body, b.body = b.body, a.body + a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream } func swapResponseBody(a, b *Response) { a.body, b.body = b.body, a.body + a.bodyRaw, b.bodyRaw = b.bodyRaw, a.bodyRaw a.bodyStream, b.bodyStream = b.bodyStream, a.bodyStream } // URI returns request URI func (req *Request) URI() *URI { - req.parseURI() + req.parseURI() //nolint:errcheck return &req.uri } -func (req *Request) parseURI() { +func (req *Request) parseURI() error { if req.parsedURI { - return + return nil } req.parsedURI = true - req.uri.parseQuick(req.Header.RequestURI(), &req.Header) + return req.uri.parse(req.Header.Host(), req.Header.RequestURI(), req.isTLS) } // PostArgs returns POST arguments. @@ -625,28 +822,47 @@ return nil, ErrNoMultipartForm } + var err error ce := req.Header.peek(strContentEncoding) - body := req.bodyBytes() - if bytes.Equal(ce, strGzip) { - // Do not care about memory usage here. - var err error - if body, err = AppendGunzipBytes(nil, body); err != nil { - return nil, fmt.Errorf("cannot gunzip request body: %s", err) + + if req.bodyStream != nil { + bodyStream := req.bodyStream + if bytes.Equal(ce, strGzip) { + // Do not care about memory usage here. + if bodyStream, err = gzip.NewReader(bodyStream); err != nil { + return nil, fmt.Errorf("cannot gunzip request body: %s", err) + } + } else if len(ce) > 0 { + return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) } - } else if len(ce) > 0 { - return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) - } - f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) - if err != nil { - return nil, err + mr := multipart.NewReader(bodyStream, req.multipartFormBoundary) + req.multipartForm, err = mr.ReadForm(8 * 1024) + if err != nil { + return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err) + } + } else { + body := req.bodyBytes() + if bytes.Equal(ce, strGzip) { + // Do not care about memory usage here. 
+ if body, err = AppendGunzipBytes(nil, body); err != nil { + return nil, fmt.Errorf("cannot gunzip request body: %s", err) + } + } else if len(ce) > 0 { + return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce) + } + + req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body)) + if err != nil { + return nil, err + } } - req.multipartForm = f - return f, nil + + return req.multipartForm, nil } func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) { - var buf ByteBuffer + var buf bytebufferpool.ByteBuffer if err := WriteMultipartForm(&buf, f, boundary); err != nil { return nil, err } @@ -657,7 +873,7 @@ // boundary to w. func WriteMultipartForm(w io.Writer, f *multipart.Form, boundary string) error { // Do not care about memory allocations here, since multipart - // form processing is slooow. + // form processing is slow. if len(boundary) == 0 { panic("BUG: form boundary cannot be empty") } @@ -679,7 +895,7 @@ // marshal files for k, fvv := range f.File { for _, fv := range fvv { - vw, err := mw.CreateFormFile(k, fv.Filename) + vw, err := mw.CreatePart(fv.Header) if err != nil { return fmt.Errorf("cannot create form file %q (%q): %s", k, fv.Filename, err) } @@ -709,7 +925,7 @@ // in multipart/form-data requests. if size <= 0 { - panic(fmt.Sprintf("BUG: form size must be greater than 0. Given %d", size)) + return nil, fmt.Errorf("form size must be greater than 0. Given %d", size) } lr := io.LimitReader(r, int64(size)) mr := multipart.NewReader(lr, boundary) @@ -724,6 +940,7 @@ func (req *Request) Reset() { req.Header.Reset() req.resetSkipHeader() + req.timeout = 0 } func (req *Request) resetSkipHeader() { @@ -732,6 +949,7 @@ req.parsedURI = false req.postArgs.Reset() req.parsedPostArgs = false + req.isTLS = false } // RemoveMultipartFormFiles removes multipart/form-data temporary files @@ -740,7 +958,7 @@ if req.multipartForm != nil { // Do not check for error, since these files may be deleted or moved // to new places by user code. - req.multipartForm.RemoveAll() + req.multipartForm.RemoveAll() //nolint:errcheck req.multipartForm = nil } req.multipartFormBoundary = "" @@ -751,6 +969,9 @@ resp.Header.Reset() resp.resetSkipHeader() resp.SkipBody = false + resp.raddr = nil + resp.laddr = nil + resp.ImmediateHeaderFlush = false } func (resp *Response) resetSkipHeader() { @@ -778,7 +999,9 @@ const defaultMaxInMemoryFileSize = 16 * 1024 * 1024 -var errGetOnly = errors.New("non-GET request received") +// ErrGetOnly is returned when server expects only GET requests, +// but some other type of request came (Server.GetOnly option is true). +var ErrGetOnly = errors.New("non-GET request received") // ReadLimitBody reads request from the given r, limiting the body size. // @@ -799,23 +1022,40 @@ // // io.EOF is returned if r is closed before reading the first header byte. func (req *Request) ReadLimitBody(r *bufio.Reader, maxBodySize int) error { - return req.readLimitBody(r, maxBodySize, false) -} - -func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool) error { req.resetSkipHeader() - err := req.Header.Read(r) - if err != nil { + if err := req.Header.Read(r); err != nil { return err } + + return req.readLimitBody(r, maxBodySize, false, true) +} + +func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error { + // Do not reset the request here - the caller must reset it before + // calling this method. 
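When a request carries 'Expect: 100-continue', ReadLimitBody stops after the headers and leaves the decision to the caller: send a 100 Continue interim response, then read the body. A sketch of that flow for a hand-rolled connection loop (br and bw are assumed to wrap the client connection; maxBodySize 0 disables the size limit):

func readFullRequest(req *fasthttp.Request, br *bufio.Reader, bw *bufio.Writer) error {
	if err := req.Header.Read(br); err != nil {
		return err
	}
	if req.MayContinue() {
		// The client is waiting for permission before sending the body.
		if _, err := bw.WriteString("HTTP/1.1 100 Continue\r\n\r\n"); err != nil {
			return err
		}
		if err := bw.Flush(); err != nil {
			return err
		}
	}
	return req.ContinueReadBody(br, 0)
}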
+ if getOnly && !req.Header.IsGet() { - return errGetOnly + return ErrGetOnly } - if req.Header.noBody() { + if req.MayContinue() { + // 'Expect: 100-continue' header found. Let the caller deciding + // whether to read request body or + // to return StatusExpectationFailed. return nil } + return req.ContinueReadBody(r, maxBodySize, preParseMultipartForm) +} + +func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error { + // Do not reset the request here - the caller must reset it before + // calling this method. + + if getOnly && !req.Header.IsGet() { + return ErrGetOnly + } + if req.MayContinue() { // 'Expect: 100-continue' header found. Let the caller deciding // whether to read request body or @@ -823,7 +1063,7 @@ return nil } - return req.ContinueReadBody(r, maxBodySize) + return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm) } // MayContinue returns true if the request contains @@ -847,24 +1087,26 @@ // // If maxBodySize > 0 and the body size exceeds maxBodySize, // then ErrBodyTooLarge is returned. -func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int) error { +func (req *Request) ContinueReadBody(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error { var err error - contentLength := req.Header.ContentLength() + contentLength := req.Header.realContentLength() if contentLength > 0 { if maxBodySize > 0 && contentLength > maxBodySize { return ErrBodyTooLarge } - // Pre-read multipart form data of known length. - // This way we limit memory usage for large file uploads, since their contents - // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. - req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) - if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { - req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) - if err != nil { - req.Reset() + if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] { + // Pre-read multipart form data of known length. + // This way we limit memory usage for large file uploads, since their contents + // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. + req.multipartFormBoundary = string(req.Header.MultipartFormBoundary()) + if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { + req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) + if err != nil { + req.Reset() + } + return err } - return err } } @@ -873,7 +1115,10 @@ // the end of body is determined by connection close. // So just ignore request body for requests without // 'Content-Length' and 'Transfer-Encoding' headers. - req.Header.SetContentLength(0) + // refer to https://tools.ietf.org/html/rfc7230#section-3.3.2 + if !req.Header.ignoreBody() { + req.Header.SetContentLength(0) + } return nil } @@ -888,6 +1133,66 @@ return nil } +// ContinueReadBody reads request body if request header contains +// 'Expect: 100-continue'. +// +// The caller must send StatusContinue response before calling this method. +// +// If maxBodySize > 0 and the body size exceeds maxBodySize, +// then ErrBodyTooLarge is returned. 
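The new trailing preParseMultipartForm argument lets a caller keep the raw multipart payload in req.Body() instead of having it pre-parsed into a form, mirroring TestRequestContinueReadBodyDisablePrereadMultipartForm later in this patch. A sketch, assuming the request carries no 'Expect: 100-continue' header:

func readRawMultipart(req *fasthttp.Request, br *bufio.Reader) error {
	if err := req.Header.Read(br); err != nil {
		return err
	}
	// false: skip the multipart pre-read so the form data stays in req.Body().
	return req.ContinueReadBody(br, 10000, false)
}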
+func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, preParseMultipartForm ...bool) error { + var err error + contentLength := req.Header.realContentLength() + if contentLength > 0 { + if len(preParseMultipartForm) == 0 || preParseMultipartForm[0] { + // Pre-read multipart form data of known length. + // This way we limit memory usage for large file uploads, since their contents + // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize. + req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary()) + if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 { + req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize) + if err != nil { + req.Reset() + } + return err + } + } + } + + if contentLength == -2 { + // identity body has no sense for http requests, since + // the end of body is determined by connection close. + // So just ignore request body for requests without + // 'Content-Length' and 'Transfer-Encoding' headers. + req.Header.SetContentLength(0) + return nil + } + + bodyBuf := req.bodyBuffer() + bodyBuf.Reset() + bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B) + if err != nil { + if err == ErrBodyTooLarge { + req.Header.SetContentLength(contentLength) + req.body = bodyBuf + req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) + return nil + } + if err == errChunkedStream { + req.body = bodyBuf + req.bodyStream = acquireRequestStream(bodyBuf, r, -1) + return nil + } + req.Reset() + return err + } + + req.body = bodyBuf + req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength) + req.Header.SetContentLength(contentLength) + return nil +} + // Read reads response (including body) from the given r. // // io.EOF is returned if r is closed before reading the first header byte. @@ -919,7 +1224,6 @@ bodyBuf.Reset() bodyBuf.B, err = readBody(r, resp.Header.ContentLength(), maxBodySize, bodyBuf.B) if err != nil { - resp.Reset() return err } resp.Header.SetContentLength(len(bodyBuf.B)) @@ -1024,6 +1328,25 @@ } req.Header.SetHostBytes(host) req.Header.SetRequestURIBytes(uri.RequestURI()) + + if len(uri.username) > 0 { + // RequestHeader.SetBytesKV only uses RequestHeader.bufKV.key + // So we are free to use RequestHeader.bufKV.value as a scratch pad for + // the base64 encoding. + nl := len(uri.username) + len(uri.password) + 1 + nb := nl + len(strBasicSpace) + tl := nb + base64.StdEncoding.EncodedLen(nl) + if tl > cap(req.Header.bufKV.value) { + req.Header.bufKV.value = make([]byte, 0, tl) + } + buf := req.Header.bufKV.value[:0] + buf = append(buf, uri.username...) + buf = append(buf, strColon...) + buf = append(buf, uri.password...) + buf = append(buf, strBasicSpace...) 
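ContinueReadBodyStream and readBodyWithStreaming back the request-body streaming path: only a prefix of the body is buffered and the rest is read on demand through ctx.RequestBodyStream(). A handler sketch, assuming the Server option that enables this path is StreamRequestBody (the option itself is not shown in this hunk):

func newUploadServer() *fasthttp.Server {
	return &fasthttp.Server{
		StreamRequestBody: true, // assumed option enabling the streaming body path
		Handler: func(ctx *fasthttp.RequestCtx) {
			// Consume the upload incrementally instead of buffering it in memory.
			n, err := io.Copy(io.Discard, ctx.RequestBodyStream())
			if err != nil {
				ctx.Error("cannot read body", fasthttp.StatusInternalServerError)
				return
			}
			fmt.Fprintf(ctx, "read %d bytes\n", n)
		},
	}
}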
+ base64.StdEncoding.Encode(buf[nb:tl], buf[:nl]) + req.Header.SetBytesKV(strAuthorization, buf[nl:tl]) + } } if req.bodyStream != nil { @@ -1040,8 +1363,12 @@ req.Header.SetMultipartFormBoundary(req.multipartFormBoundary) } - hasBody := !req.Header.noBody() - if hasBody { + hasBody := false + if len(body) == 0 { + body = req.postArgs.QueryString() + } + if len(body) != 0 || !req.Header.ignoreBody() { + hasBody = true req.Header.SetContentLength(len(body)) } if err = req.Header.Write(w); err != nil { @@ -1050,6 +1377,9 @@ if hasBody { _, err = w.Write(body) } else if len(body) > 0 { + if req.secureErrorLogMessage { + return fmt.Errorf("non-zero body for non-POST request") + } return fmt.Errorf("non-zero body for non-POST request. body=%q", body) } return err @@ -1073,6 +1403,7 @@ // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression +// * CompressHuffmanOnly // // The method gzips response body and sets 'Content-Encoding: gzip' // header before writing response to w. @@ -1103,6 +1434,7 @@ // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression +// * CompressHuffmanOnly // // The method deflates response body and sets 'Content-Encoding: deflate' // header before writing response to w. @@ -1115,66 +1447,198 @@ return resp.Write(w) } +func (resp *Response) brotliBody(level int) error { + if len(resp.Header.peek(strContentEncoding)) > 0 { + // It looks like the body is already compressed. + // Do not compress it again. + return nil + } + + if !resp.Header.isCompressibleContentType() { + // The content-type cannot be compressed. + return nil + } + + if resp.bodyStream != nil { + // Reset Content-Length to -1, since it is impossible + // to determine body size beforehand of streamed compression. + // For https://github.com/valyala/fasthttp/issues/176 . + resp.Header.SetContentLength(-1) + + // Do not care about memory allocations here, since brotli is slow + // and allocates a lot of memory by itself. + bs := resp.bodyStream + resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { + zw := acquireStacklessBrotliWriter(sw, level) + fw := &flushWriter{ + wf: zw, + bw: sw, + } + copyZeroAlloc(fw, bs) //nolint:errcheck + releaseStacklessBrotliWriter(zw, level) + if bsc, ok := bs.(io.Closer); ok { + bsc.Close() + } + }) + } else { + bodyBytes := resp.bodyBytes() + if len(bodyBytes) < minCompressLen { + // There is no sense in spending CPU time on small body compression, + // since there is a very high probability that the compressed + // body size will be bigger than the original body size. + return nil + } + w := responseBodyPool.Get() + w.B = AppendBrotliBytesLevel(w.B, bodyBytes, level) + + // Hack: swap resp.body with w. + if resp.body != nil { + responseBodyPool.Put(resp.body) + } + resp.body = w + resp.bodyRaw = nil + } + resp.Header.SetCanonical(strContentEncoding, strBr) + return nil +} + func (resp *Response) gzipBody(level int) error { - // Do not care about memory allocations here, since gzip is slow - // and allocates a lot of memory by itself. + if len(resp.Header.peek(strContentEncoding)) > 0 { + // It looks like the body is already compressed. + // Do not compress it again. + return nil + } + + if !resp.Header.isCompressibleContentType() { + // The content-type cannot be compressed. + return nil + } + if resp.bodyStream != nil { + // Reset Content-Length to -1, since it is impossible + // to determine body size beforehand of streamed compression. + // For https://github.com/valyala/fasthttp/issues/176 . 
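With the change above, credentials embedded in the request URI are serialized as an 'Authorization: Basic ...' header when the request is written. A sketch that makes the behaviour visible (imports assumed: bufio, bytes and github.com/valyala/fasthttp):

func dumpAuthorizedRequest() (string, error) {
	req := fasthttp.AcquireRequest()
	defer fasthttp.ReleaseRequest(req)
	req.SetRequestURI("http://user:secret@example.com/private")

	var buf bytes.Buffer
	bw := bufio.NewWriter(&buf)
	if err := req.Write(bw); err != nil {
		return "", err
	}
	if err := bw.Flush(); err != nil {
		return "", err
	}
	// The dump contains "Authorization: Basic dXNlcjpzZWNyZXQ=".
	return buf.String(), nil
}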
+ resp.Header.SetContentLength(-1) + + // Do not care about memory allocations here, since gzip is slow + // and allocates a lot of memory by itself. bs := resp.bodyStream resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { - zw := acquireGzipWriter(sw, level) - copyZeroAlloc(zw, bs) - releaseGzipWriter(zw) + zw := acquireStacklessGzipWriter(sw, level) + fw := &flushWriter{ + wf: zw, + bw: sw, + } + copyZeroAlloc(fw, bs) //nolint:errcheck + releaseStacklessGzipWriter(zw, level) if bsc, ok := bs.(io.Closer); ok { bsc.Close() } }) } else { - w := responseBodyPool.Acquire() - zw := acquireGzipWriter(w, level) - _, err := zw.Write(resp.bodyBytes()) - releaseGzipWriter(zw) - if err != nil { - return err + bodyBytes := resp.bodyBytes() + if len(bodyBytes) < minCompressLen { + // There is no sense in spending CPU time on small body compression, + // since there is a very high probability that the compressed + // body size will be bigger than the original body size. + return nil } + w := responseBodyPool.Get() + w.B = AppendGzipBytesLevel(w.B, bodyBytes, level) // Hack: swap resp.body with w. - responseBodyPool.Release(resp.body) + if resp.body != nil { + responseBodyPool.Put(resp.body) + } resp.body = w + resp.bodyRaw = nil } resp.Header.SetCanonical(strContentEncoding, strGzip) return nil } func (resp *Response) deflateBody(level int) error { - // Do not care about memory allocations here, since flate is slow - // and allocates a lot of memory by itself. + if len(resp.Header.peek(strContentEncoding)) > 0 { + // It looks like the body is already compressed. + // Do not compress it again. + return nil + } + + if !resp.Header.isCompressibleContentType() { + // The content-type cannot be compressed. + return nil + } + if resp.bodyStream != nil { + // Reset Content-Length to -1, since it is impossible + // to determine body size beforehand of streamed compression. + // For https://github.com/valyala/fasthttp/issues/176 . + resp.Header.SetContentLength(-1) + + // Do not care about memory allocations here, since flate is slow + // and allocates a lot of memory by itself. bs := resp.bodyStream resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) { - zw := acquireFlateWriter(sw, level) - copyZeroAlloc(zw, bs) - releaseFlateWriter(zw) + zw := acquireStacklessDeflateWriter(sw, level) + fw := &flushWriter{ + wf: zw, + bw: sw, + } + copyZeroAlloc(fw, bs) //nolint:errcheck + releaseStacklessDeflateWriter(zw, level) if bsc, ok := bs.(io.Closer); ok { bsc.Close() } }) } else { - w := responseBodyPool.Acquire() - zw := acquireFlateWriter(w, level) - _, err := zw.Write(resp.bodyBytes()) - releaseFlateWriter(zw) - if err != nil { - return err + bodyBytes := resp.bodyBytes() + if len(bodyBytes) < minCompressLen { + // There is no sense in spending CPU time on small body compression, + // since there is a very high probability that the compressed + // body size will be bigger than the original body size. + return nil } + w := responseBodyPool.Get() + w.B = AppendDeflateBytesLevel(w.B, bodyBytes, level) // Hack: swap resp.body with w. 
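Two guards now short-circuit compression: bodies that already carry a Content-Encoding, and content types the header layer reports as non-compressible; in addition, non-streamed bodies shorter than minCompressLen (200 bytes) are left alone because compressing them usually grows them. A sketch of the observable effect through the public WriteGzip API:

func gzipEffect(body string) (encoding string, err error) {
	var resp fasthttp.Response
	resp.Header.SetContentType("text/plain; charset=utf-8")
	resp.SetBodyString(body)

	var buf bytes.Buffer
	bw := bufio.NewWriter(&buf)
	if err = resp.WriteGzip(bw); err != nil {
		return "", err
	}
	if err = bw.Flush(); err != nil {
		return "", err
	}

	var parsed fasthttp.Response
	if err = parsed.Read(bufio.NewReader(&buf)); err != nil {
		return "", err
	}
	// Empty for short bodies (< ~200 bytes), "gzip" for larger text bodies.
	return string(parsed.Header.Peek(fasthttp.HeaderContentEncoding)), nil
}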
- responseBodyPool.Release(resp.body) + if resp.body != nil { + responseBodyPool.Put(resp.body) + } resp.body = w + resp.bodyRaw = nil } resp.Header.SetCanonical(strContentEncoding, strDeflate) return nil } +// Bodies with sizes smaller than minCompressLen aren't compressed at all +const minCompressLen = 200 + +type writeFlusher interface { + io.Writer + Flush() error +} + +type flushWriter struct { + wf writeFlusher + bw *bufio.Writer +} + +func (w *flushWriter) Write(p []byte) (int, error) { + n, err := w.wf.Write(p) + if err != nil { + return 0, err + } + if err = w.wf.Flush(); err != nil { + return 0, err + } + if err = w.bw.Flush(); err != nil { + return 0, err + } + return n, nil +} + // Write writes response to w. // // Write doesn't flush response to w for performance reasons. @@ -1236,8 +1700,19 @@ return err } -func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) error { - var err error +// ErrBodyStreamWritePanic is returned when panic happens during writing body stream. +type ErrBodyStreamWritePanic struct { + error +} + +func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error) { + defer func() { + if r := recover(); r != nil { + err = &ErrBodyStreamWritePanic{ + error: fmt.Errorf("panic while writing body stream: %+v", r), + } + } + }() contentLength := resp.Header.ContentLength() if contentLength < 0 { @@ -1253,13 +1728,23 @@ } } if contentLength >= 0 { - if err = resp.Header.Write(w); err == nil && sendBody { - err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)) + if err = resp.Header.Write(w); err == nil { + if resp.ImmediateHeaderFlush { + err = w.Flush() + } + if err == nil && sendBody { + err = writeBodyFixedSize(w, resp.bodyStream, int64(contentLength)) + } } } else { resp.Header.SetContentLength(-1) - if err = resp.Header.Write(w); err == nil && sendBody { - err = writeBodyChunked(w, resp.bodyStream) + if err = resp.Header.Write(w); err == nil { + if resp.ImmediateHeaderFlush { + err = w.Flush() + } + if err == nil && sendBody { + err = writeBodyChunked(w, resp.bodyStream) + } } } err1 := resp.closeBodyStream() @@ -1312,7 +1797,7 @@ } func getHTTPString(hw httpWriter) string { - w := AcquireByteBuffer() + w := bytebufferpool.Get() bw := bufio.NewWriter(w) if err := hw.Write(bw); err != nil { return err.Error() @@ -1321,7 +1806,7 @@ return err.Error() } s := string(w.B) - ReleaseByteBuffer(w) + bytebufferpool.Put(w) return s } @@ -1375,19 +1860,8 @@ } } - // Unwrap a single limited reader for triggering sendfile path - // in net.TCPConn.ReadFrom. 
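ImmediateHeaderFlush pairs with streamed bodies: writeBodyStream now flushes the headers before the first body bytes, so long-running streams can signal progress early. A handler sketch (imports assumed: bufio, fmt and time):

func progressHandler(ctx *fasthttp.RequestCtx) {
	ctx.Response.ImmediateHeaderFlush = true
	ctx.Response.SetBodyStreamWriter(func(w *bufio.Writer) {
		for i := 0; i < 3; i++ {
			fmt.Fprintf(w, "tick %d\n", i)
			if err := w.Flush(); err != nil {
				return // the client is gone
			}
			time.Sleep(100 * time.Millisecond)
		}
	})
}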
- lr, ok := r.(*io.LimitedReader) - if ok { - r = lr.R - } - n, err := copyZeroAlloc(w, r) - if ok { - lr.N -= n - } - if n != size && err == nil { err = fmt.Errorf("copied %d bytes from body stream instead of %d bytes", n, size) } @@ -1410,9 +1884,15 @@ func writeChunk(w *bufio.Writer, b []byte) error { n := len(b) - writeHexInt(w, n) - w.Write(strCRLF) - w.Write(b) + if err := writeHexInt(w, n); err != nil { + return err + } + if _, err := w.Write(strCRLF); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } _, err := w.Write(strCRLF) err1 := w.Flush() if err == nil { @@ -1439,6 +1919,39 @@ return readBodyIdentity(r, maxBodySize, dst) } +var errChunkedStream = errors.New("chunked stream") + +func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) { + if contentLength == -1 { + // handled in requestStream.Read() + return b, errChunkedStream + } + + dst = dst[:0] + + readN := maxBodySize + if readN > contentLength { + readN = contentLength + } + if readN > 8*1024 { + readN = 8 * 1024 + } + + if contentLength >= 0 && maxBodySize >= contentLength { + b, err = appendBodyFixedSize(r, dst, readN) + } else { + b, err = readBodyIdentity(r, readN, dst) + } + + if err != nil { + return b, err + } + if contentLength > maxBodySize { + return b, ErrBodyTooLarge + } + return b, nil +} + func readBodyIdentity(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) { dst = dst[:cap(dst)] if len(dst) == 0 { @@ -1504,6 +2017,11 @@ } } +// ErrBrokenChunk is returned when server receives a broken chunked body (Transfer-Encoding: chunked). +type ErrBrokenChunk struct { + error +} + func readBodyChunked(r *bufio.Reader, maxBodySize int, dst []byte) ([]byte, error) { if len(dst) > 0 { panic("BUG: expected zero-length buffer") @@ -1523,7 +2041,9 @@ return dst, err } if !bytes.Equal(dst[len(dst)-strCRLFLen:], strCRLF) { - return dst, fmt.Errorf("cannot find crlf at the end of chunk") + return dst, ErrBrokenChunk{ + error: fmt.Errorf("cannot find crlf at the end of chunk"), + } } dst = dst[:len(dst)-strCRLFLen] if chunkSize == 0 { @@ -1537,32 +2057,59 @@ if err != nil { return -1, err } - c, err := r.ReadByte() - if err != nil { - return -1, fmt.Errorf("cannot read '\r' char at the end of chunk size: %s", err) - } - if c != '\r' { - return -1, fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, '\r') + for { + c, err := r.ReadByte() + if err != nil { + return -1, ErrBrokenChunk{ + error: fmt.Errorf("cannot read '\r' char at the end of chunk size: %s", err), + } + } + // Skip any trailing whitespace after chunk size. + if c == ' ' { + continue + } + if err := r.UnreadByte(); err != nil { + return -1, ErrBrokenChunk{ + error: fmt.Errorf("cannot unread '\r' char at the end of chunk size: %s", err), + } + } + break } - c, err = r.ReadByte() + err = readCrLf(r) if err != nil { - return -1, fmt.Errorf("cannot read '\n' char at the end of chunk size: %s", err) - } - if c != '\n' { - return -1, fmt.Errorf("unexpected char %q at the end of chunk size. Expected %q", c, '\n') + return -1, err } return n, nil } +func readCrLf(r *bufio.Reader) error { + for _, exp := range []byte{'\r', '\n'} { + c, err := r.ReadByte() + if err != nil { + return ErrBrokenChunk{ + error: fmt.Errorf("cannot read %q char at the end of chunk size: %s", exp, err), + } + } + if c != exp { + return ErrBrokenChunk{ + error: fmt.Errorf("unexpected char %q at the end of chunk size. 
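The chunk-size parser now tolerates trailing spaces between the chunk size and its CRLF, which some clients emit; TestRequestChunkedWhitespace later in this patch covers the same input. A round-trip sketch (imports assumed: bufio, strings and github.com/valyala/fasthttp):

func readChunkedWithSpaces() (string, error) {
	raw := "POST /foo HTTP/1.1\r\nHost: example.com\r\nContent-Type: aa/bb\r\nTransfer-Encoding: chunked\r\n\r\n" +
		"3 \r\nabc\r\n0\r\n\r\n"
	var req fasthttp.Request
	if err := req.Read(bufio.NewReader(strings.NewReader(raw))); err != nil {
		return "", err
	}
	return string(req.Body()), nil // "abc"
}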
Expected %q", c, exp), + } + } + } + return nil +} + func round2(n int) int { if n <= 0 { return 0 } - n-- - x := uint(0) - for n > 0 { - n >>= 1 - x++ - } - return 1 << x + + x := uint32(n - 1) + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + + return int(x + 1) } diff -Nru golang-github-valyala-fasthttp-20160617/http_test.go golang-github-valyala-fasthttp-1.31.0/http_test.go --- golang-github-valyala-fasthttp-20160617/http_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/http_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -7,11 +7,375 @@ "io" "io/ioutil" "mime/multipart" + "reflect" + "strconv" "strings" "testing" + "time" + + "github.com/valyala/bytebufferpool" ) +func TestResponseEmptyTransferEncoding(t *testing.T) { + t.Parallel() + + var r Response + + body := "Some body" + br := bufio.NewReader(bytes.NewBufferString("HTTP/1.1 200 OK\r\nContent-Type: aaa\r\nTransfer-Encoding: \r\nContent-Length: 9\r\n\r\n" + body)) + err := r.Read(br) + if err != nil { + t.Fatal(err) + } + if got := string(r.Body()); got != body { + t.Fatalf("expected %q got %q", body, got) + } +} + +// Don't send the fragment/hash/# part of a URL to the server. +func TestFragmentInURIRequest(t *testing.T) { + t.Parallel() + + var req Request + req.SetRequestURI("https://docs.gitlab.com/ee/user/project/integrations/webhooks.html#events") + + var b bytes.Buffer + req.WriteTo(&b) //nolint:errcheck + got := b.String() + expected := "GET /ee/user/project/integrations/webhooks.html HTTP/1.1\r\nHost: docs.gitlab.com\r\n\r\n" + + if got != expected { + t.Errorf("got %q expected %q", got, expected) + } +} + +func TestIssue875(t *testing.T) { + t.Parallel() + + type testcase struct { + uri string + expectedRedirect string + expectedLocation string + } + + var testcases = []testcase{ + { + uri: `http://localhost:3000/?redirect=foo%0d%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\r\nSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + { + uri: `http://localhost:3000/?redirect=foo%0dSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\rSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + { + uri: `http://localhost:3000/?redirect=foo%0aSet-Cookie:%20SESSIONID=MaliciousValue%0d%0a`, + expectedRedirect: "foo\nSet-Cookie: SESSIONID=MaliciousValue\r\n", + expectedLocation: "Location: foo Set-Cookie: SESSIONID=MaliciousValue", + }, + } + + for i, tcase := range testcases { + caseName := strconv.FormatInt(int64(i), 10) + t.Run(caseName, func(subT *testing.T) { + ctx := &RequestCtx{ + Request: Request{}, + Response: Response{}, + } + ctx.Request.SetRequestURI(tcase.uri) + + q := string(ctx.QueryArgs().Peek("redirect")) + if q != tcase.expectedRedirect { + subT.Errorf("unexpected redirect query value, got: %+v", q) + } + ctx.Response.Header.Set("Location", q) + + if !strings.Contains(ctx.Response.String(), tcase.expectedLocation) { + subT.Errorf("invalid escaping, got\n%s", ctx.Response.String()) + } + }) + } +} + +func TestRequestCopyTo(t *testing.T) { + t.Parallel() + + var req Request + + // empty copy + testRequestCopyTo(t, &req) + + // init + expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8" + expectedHost := "test.com" + expectedBody := "0123=56789" + s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: %s\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", + 
expectedHost, expectedContentType, len(expectedBody), expectedBody) + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := req.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + testRequestCopyTo(t, &req) + +} + +func TestResponseCopyTo(t *testing.T) { + t.Parallel() + + var resp Response + + // empty copy + testResponseCopyTo(t, &resp) + + // init resp + resp.laddr = zeroTCPAddr + resp.SkipBody = true + resp.Header.SetStatusCode(200) + resp.SetBodyString("test") + testResponseCopyTo(t, &resp) + +} + +func testRequestCopyTo(t *testing.T, src *Request) { + var dst Request + src.CopyTo(&dst) + + if !reflect.DeepEqual(*src, dst) { //nolint:govet + t.Fatalf("RequestCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet + } +} + +func testResponseCopyTo(t *testing.T, src *Response) { + var dst Response + src.CopyTo(&dst) + + if !reflect.DeepEqual(*src, dst) { //nolint:govet + t.Fatalf("ResponseCopyTo fail, src: \n%+v\ndst: \n%+v\n", *src, dst) //nolint:govet + } +} + +func TestResponseBodyStreamDeflate(t *testing.T) { + t.Parallel() + + body := createFixedBody(1e5) + + // Verifies https://github.com/valyala/fasthttp/issues/176 + // when Content-Length is explicitly set. + testResponseBodyStreamDeflate(t, body, len(body)) + + // Verifies that 'transfer-encoding: chunked' works as expected. + testResponseBodyStreamDeflate(t, body, -1) +} + +func TestResponseBodyStreamGzip(t *testing.T) { + t.Parallel() + + body := createFixedBody(1e5) + + // Verifies https://github.com/valyala/fasthttp/issues/176 + // when Content-Length is explicitly set. + testResponseBodyStreamGzip(t, body, len(body)) + + // Verifies that 'transfer-encoding: chunked' works as expected. + testResponseBodyStreamGzip(t, body, -1) +} + +func testResponseBodyStreamDeflate(t *testing.T, body []byte, bodySize int) { + var r Response + r.SetBodyStream(bytes.NewReader(body), bodySize) + + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := r.WriteDeflate(bw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var resp Response + br := bufio.NewReader(w) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + respBody, err := resp.BodyInflate() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !bytes.Equal(respBody, body) { + t.Fatalf("unexpected body: %q. Expecting %q", respBody, body) + } +} + +func testResponseBodyStreamGzip(t *testing.T, body []byte, bodySize int) { + var r Response + r.SetBodyStream(bytes.NewReader(body), bodySize) + + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := r.WriteGzip(bw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var resp Response + br := bufio.NewReader(w) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + respBody, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if !bytes.Equal(respBody, body) { + t.Fatalf("unexpected body: %q. 
Expecting %q", respBody, body) + } +} + +func TestResponseWriteGzipNilBody(t *testing.T) { + t.Parallel() + + var r Response + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := r.WriteGzip(bw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestResponseWriteDeflateNilBody(t *testing.T) { + t.Parallel() + + var r Response + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := r.WriteDeflate(bw); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestResponseSwapBodySerial(t *testing.T) { + t.Parallel() + + testResponseSwapBody(t) +} + +func TestResponseSwapBodyConcurrent(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + testResponseSwapBody(t) + ch <- struct{}{} + }() + } + + for i := 0; i < 10; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } +} + +func testResponseSwapBody(t *testing.T) { + var b []byte + r := AcquireResponse() + for i := 0; i < 20; i++ { + bOrig := r.Body() + b = r.SwapBody(b) + if !bytes.Equal(bOrig, b) { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig) + } + r.AppendBodyString("foobar") + } + + s := "aaaabbbbcccc" + b = b[:0] + for i := 0; i < 10; i++ { + r.SetBodyStream(bytes.NewBufferString(s), len(s)) + b = r.SwapBody(b) + if string(b) != s { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, s) + } + b = r.SwapBody(b) + if len(b) > 0 { + t.Fatalf("unexpected body with non-zero size returned: %q", b) + } + } + ReleaseResponse(r) +} + +func TestRequestSwapBodySerial(t *testing.T) { + t.Parallel() + + testRequestSwapBody(t) +} + +func TestRequestSwapBodyConcurrent(t *testing.T) { + t.Parallel() + + ch := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + testRequestSwapBody(t) + ch <- struct{}{} + }() + } + + for i := 0; i < 10; i++ { + select { + case <-ch: + case <-time.After(time.Second): + t.Fatalf("timeout") + } + } +} + +func testRequestSwapBody(t *testing.T) { + var b []byte + r := AcquireRequest() + for i := 0; i < 20; i++ { + bOrig := r.Body() + b = r.SwapBody(b) + if !bytes.Equal(bOrig, b) { + t.Fatalf("unexpected body returned: %q. Expecting %q", b, bOrig) + } + r.AppendBodyString("foobar") + } + + s := "aaaabbbbcccc" + b = b[:0] + for i := 0; i < 10; i++ { + r.SetBodyStream(bytes.NewBufferString(s), len(s)) + b = r.SwapBody(b) + if string(b) != s { + t.Fatalf("unexpected body returned: %q. 
Expecting %q", b, s) + } + b = r.SwapBody(b) + if len(b) > 0 { + t.Fatalf("unexpected body with non-zero size returned: %q", b) + } + } + ReleaseRequest(r) +} + func TestRequestHostFromRequestURI(t *testing.T) { + t.Parallel() + hExpected := "foobar.com" var req Request req.SetRequestURI("http://proxy-host:123/foobar?baz") @@ -23,6 +387,8 @@ } func TestRequestHostFromHeader(t *testing.T) { + t.Parallel() + hExpected := "foobar.com" var req Request req.Header.SetHost(hExpected) @@ -33,6 +399,8 @@ } func TestRequestContentTypeWithCharsetIssue100(t *testing.T) { + t.Parallel() + expectedContentType := "application/x-www-form-urlencoded; charset=UTF-8" expectedBody := "0123=56789" s := fmt.Sprintf("POST / HTTP/1.1\r\nContent-Type: %s\r\nContent-Length: %d\r\n\r\n%s", @@ -63,6 +431,8 @@ } func TestRequestReadMultipartFormWithFile(t *testing.T) { + t.Parallel() + s := `POST /upload HTTP/1.1 Host: localhost:10000 Content-Length: 521 @@ -146,6 +516,8 @@ } func TestRequestRequestURI(t *testing.T) { + t.Parallel() + var r Request // Set request uri via SetRequestURI() @@ -174,6 +546,8 @@ } func TestRequestUpdateURI(t *testing.T) { + t.Parallel() + var r Request r.Header.SetHost("aaa.bbb") r.SetRequestURI("/lkjkl/kjl") @@ -190,12 +564,14 @@ if !strings.HasPrefix(s, "GET /123/432.html?aaa=bcse") { t.Fatalf("cannot find %q in %q", "GET /123/432.html?aaa=bcse", s) } - if strings.Index(s, "\r\nHost: foobar.com\r\n") < 0 { + if !strings.Contains(s, "\r\nHost: foobar.com\r\n") { t.Fatalf("cannot find %q in %q", "\r\nHost: foobar.com\r\n", s) } } func TestRequestBodyStreamMultipleBodyCalls(t *testing.T) { + t.Parallel() + var r Request s := "foobar baz abc" @@ -215,6 +591,8 @@ } func TestResponseBodyStreamMultipleBodyCalls(t *testing.T) { + t.Parallel() + var r Response s := "foobar baz abc" @@ -234,6 +612,8 @@ } func TestRequestBodyWriteToPlain(t *testing.T) { + t.Parallel() + var r Request expectedS := "foobarbaz" @@ -243,6 +623,8 @@ } func TestResponseBodyWriteToPlain(t *testing.T) { + t.Parallel() + var r Response expectedS := "foobarbaz" @@ -252,6 +634,8 @@ } func TestResponseBodyWriteToStream(t *testing.T) { + t.Parallel() + var r Response expectedS := "aaabbbccc" @@ -268,6 +652,8 @@ } func TestRequestBodyWriteToMultipart(t *testing.T) { + t.Parallel() + expectedS := "--foobar\r\nContent-Disposition: form-data; name=\"key_0\"\r\n\r\nvalue_0\r\n--foobar--\r\n" s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=foobar\r\nContent-Length: %d\r\n\r\n%s", len(expectedS), expectedS) @@ -287,7 +673,7 @@ } func testBodyWriteTo(t *testing.T, bw bodyWriterTo, expectedS string, isRetainedBody bool) { - var buf ByteBuffer + var buf bytebufferpool.ByteBuffer if err := bw.BodyWriteTo(&buf); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -310,6 +696,8 @@ } func TestRequestReadEOF(t *testing.T) { + t.Parallel() + var r Request br := bufio.NewReader(&bytes.Buffer{}) @@ -333,6 +721,8 @@ } func TestResponseReadEOF(t *testing.T) { + t.Parallel() + var r Response br := bufio.NewReader(&bytes.Buffer{}) @@ -355,13 +745,32 @@ } } +func TestRequestReadNoBody(t *testing.T) { + t.Parallel() + + var r Request + + br := bufio.NewReader(bytes.NewBufferString("GET / HTTP/1.1\r\n\r\n")) + err := r.Read(br) + r.SetHost("foobar") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + s := r.String() + if strings.Contains(s, "Content-Length: ") { + t.Fatalf("unexpected Content-Length") + } +} + func TestResponseWriteTo(t *testing.T) { + t.Parallel() + var r Response 
r.SetBodyString("foobar") s := r.String() - var buf ByteBuffer + var buf bytebufferpool.ByteBuffer n, err := r.WriteTo(&buf) if err != nil { t.Fatalf("unexpected error: %s", err) @@ -375,12 +784,14 @@ } func TestRequestWriteTo(t *testing.T) { + t.Parallel() + var r Request r.SetRequestURI("http://foobar.com/aaa/bbb") s := r.String() - var buf ByteBuffer + var buf bytebufferpool.ByteBuffer n, err := r.WriteTo(&buf) if err != nil { t.Fatalf("unexpected error: %s", err) @@ -394,6 +805,8 @@ } func TestResponseSkipBody(t *testing.T) { + t.Parallel() + var r Response // set StatusNotModified @@ -441,9 +854,11 @@ } func TestRequestNoContentLength(t *testing.T) { + t.Parallel() + var r Request - r.Header.SetMethod("HEAD") + r.Header.SetMethod(MethodHead) r.Header.SetHost("foobar") s := r.String() @@ -451,7 +866,7 @@ t.Fatalf("unexpected content-length in HEAD request %q", s) } - r.Header.SetMethod("POST") + r.Header.SetMethod(MethodPost) fmt.Fprintf(r.BodyWriter(), "foobar body") s = r.String() if !strings.Contains(s, "Content-Length: ") { @@ -460,6 +875,8 @@ } func TestRequestReadGzippedBody(t *testing.T) { + t.Parallel() + var r Request bodyOriginal := "foo bar baz compress me better!" @@ -471,8 +888,8 @@ t.Fatalf("unexpected error: %s", err) } - if string(r.Header.Peek("Content-Encoding")) != "gzip" { - t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek("Content-Encoding"), "gzip") + if string(r.Header.Peek(HeaderContentEncoding)) != "gzip" { + t.Fatalf("unexpected content-encoding: %q. Expecting %q", r.Header.Peek(HeaderContentEncoding), "gzip") } if r.Header.ContentLength() != len(body) { t.Fatalf("unexpected content-length: %d. Expecting %d", r.Header.ContentLength(), len(body)) @@ -491,6 +908,8 @@ } func TestRequestReadPostNoBody(t *testing.T) { + t.Parallel() + var r Request s := "POST /foo/bar HTTP/1.1\r\nContent-Type: aaa/bbb\r\n\r\naaaa" @@ -522,6 +941,8 @@ } func TestRequestContinueReadBody(t *testing.T) { + t.Parallel() + s := "PUT /foo/bar HTTP/1.1\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" br := bufio.NewReader(bytes.NewBufferString(s)) @@ -533,7 +954,7 @@ t.Fatalf("MayContinue must return true") } - if err := r.ContinueReadBody(br, 0); err != nil { + if err := r.ContinueReadBody(br, 0, true); err != nil { t.Fatalf("error when reading request body: %s", err) } body := r.Body() @@ -550,7 +971,51 @@ } } +func TestRequestContinueReadBodyDisablePrereadMultipartForm(t *testing.T) { + t.Parallel() + + var w bytes.Buffer + mw := multipart.NewWriter(&w) + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key_%d", i) + v := fmt.Sprintf("value_%d", i) + if err := mw.WriteField(k, v); err != nil { + t.Fatalf("unexpected error: %s", err) + } + } + boundary := mw.Boundary() + if err := mw.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + formData := w.Bytes() + + s := fmt.Sprintf("POST / HTTP/1.1\r\nHost: aaa\r\nContent-Type: multipart/form-data; boundary=%s\r\nContent-Length: %d\r\n\r\n%s", + boundary, len(formData), formData) + br := bufio.NewReader(bytes.NewBufferString(s)) + + var r Request + + if err := r.Header.Read(br); err != nil { + t.Fatalf("unexpected error reading headers: %s", err) + } + + if err := r.readLimitBody(br, 10000, false, false); err != nil { + t.Fatalf("unexpected error reading body: %s", err) + } + + if r.multipartForm != nil { + t.Fatalf("The multipartForm of the Request must be nil") + } + + if string(formData) != string(r.Body()) { + t.Fatalf("The body given must equal the body in the 
Request") + } + +} + func TestRequestMayContinue(t *testing.T) { + t.Parallel() + var r Request if r.MayContinue() { t.Fatalf("MayContinue on empty request must return false") @@ -568,6 +1033,8 @@ } func TestResponseGzipStream(t *testing.T) { + t.Parallel() + var r Response if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") @@ -575,9 +1042,11 @@ r.SetBodyStreamWriter(func(w *bufio.Writer) { fmt.Fprintf(w, "foo") w.Flush() - w.Write([]byte("barbaz")) - w.Flush() - fmt.Fprintf(w, "1234") + time.Sleep(time.Millisecond) + w.Write([]byte("barbaz")) //nolint:errcheck + w.Flush() //nolint:errcheck + time.Sleep(time.Millisecond) + fmt.Fprintf(w, "1234") //nolint:errcheck if err := w.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -589,16 +1058,18 @@ } func TestResponseDeflateStream(t *testing.T) { + t.Parallel() + var r Response if r.IsBodyStream() { t.Fatalf("IsBodyStream must return false") } r.SetBodyStreamWriter(func(w *bufio.Writer) { - w.Write([]byte("foo")) - w.Flush() - fmt.Fprintf(w, "barbaz") - w.Flush() - w.Write([]byte("1234")) + w.Write([]byte("foo")) //nolint:errcheck + w.Flush() //nolint:errcheck + fmt.Fprintf(w, "barbaz") //nolint:errcheck + w.Flush() //nolint:errcheck + w.Write([]byte("1234")) //nolint:errcheck if err := w.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -610,45 +1081,68 @@ } func TestResponseDeflate(t *testing.T) { - testResponseDeflate(t, "") - testResponseDeflate(t, "abdasdfsdaa") - testResponseDeflate(t, "asoiowqoieroqweiruqwoierqo") + t.Parallel() + + for _, s := range compressTestcases { + testResponseDeflate(t, s) + } } func TestResponseGzip(t *testing.T) { - testResponseGzip(t, "") - testResponseGzip(t, "foobarbaz") - testResponseGzip(t, "abasdwqpweoweporweprowepr") + t.Parallel() + + for _, s := range compressTestcases { + testResponseGzip(t, s) + } } func testResponseDeflate(t *testing.T, s string) { var r Response r.SetBodyString(s) testResponseDeflateExt(t, &r, s) + + // make sure the uncompressible Content-Type isn't compressed + r.Reset() + r.Header.SetContentType("image/jpeg") + r.SetBodyString(s) + testResponseDeflateExt(t, &r, s) } func testResponseDeflateExt(t *testing.T, r *Response, s string) { + isCompressible := isCompressibleResponse(r, s) + var buf bytes.Buffer + var err error bw := bufio.NewWriter(&buf) - if err := r.WriteDeflate(bw); err != nil { + if err = r.WriteDeflate(bw); err != nil { t.Fatalf("unexpected error: %s", err) } - if err := bw.Flush(); err != nil { + if err = bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var r1 Response br := bufio.NewReader(&buf) - if err := r1.Read(br); err != nil { + if err = r1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } - ce := r1.Header.Peek("Content-Encoding") - if string(ce) != "deflate" { - t.Fatalf("unexpected Content-Encoding %q. Expecting %q", ce, "deflate") - } - body, err := r1.BodyInflate() - if err != nil { - t.Fatalf("unexpected error: %s", err) + + ce := r1.Header.Peek(HeaderContentEncoding) + var body []byte + if isCompressible { + if string(ce) != "deflate" { + t.Fatalf("unexpected Content-Encoding %q. Expecting %q. len(s)=%d, Content-Type: %q", + ce, "deflate", len(s), r.Header.ContentType()) + } + body, err = r1.BodyInflate() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + } else { + if len(ce) > 0 { + t.Fatalf("expecting empty Content-Encoding. Got %q", ce) + } + body = r1.Body() } if string(body) != s { t.Fatalf("unexpected body %q. 
Expecting %q", body, s) @@ -659,37 +1153,66 @@ var r Response r.SetBodyString(s) testResponseGzipExt(t, &r, s) + + // make sure the uncompressible Content-Type isn't compressed + r.Reset() + r.Header.SetContentType("image/jpeg") + r.SetBodyString(s) + testResponseGzipExt(t, &r, s) } func testResponseGzipExt(t *testing.T, r *Response, s string) { + isCompressible := isCompressibleResponse(r, s) + var buf bytes.Buffer + var err error bw := bufio.NewWriter(&buf) - if err := r.WriteGzip(bw); err != nil { + if err = r.WriteGzip(bw); err != nil { t.Fatalf("unexpected error: %s", err) } - if err := bw.Flush(); err != nil { + if err = bw.Flush(); err != nil { t.Fatalf("unexpected error: %s", err) } var r1 Response br := bufio.NewReader(&buf) - if err := r1.Read(br); err != nil { + if err = r1.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } - ce := r1.Header.Peek("Content-Encoding") - if string(ce) != "gzip" { - t.Fatalf("unexpected Content-Encoding %q. Expecting %q", ce, "gzip") - } - body, err := r1.BodyGunzip() - if err != nil { - t.Fatalf("unexpected error: %s", err) + + ce := r1.Header.Peek(HeaderContentEncoding) + var body []byte + if isCompressible { + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding %q. Expecting %q. len(s)=%d, Content-Type: %q", + ce, "gzip", len(s), r.Header.ContentType()) + } + body, err = r1.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + } else { + if len(ce) > 0 { + t.Fatalf("Expecting empty Content-Encoding. Got %q", ce) + } + body = r1.Body() } if string(body) != s { t.Fatalf("unexpected body %q. Expecting %q", body, s) } } +func isCompressibleResponse(r *Response, s string) bool { + isCompressible := r.Header.isCompressibleContentType() + if isCompressible && len(s) < minCompressLen && !r.IsBodyStream() { + isCompressible = false + } + return isCompressible +} + func TestRequestMultipartForm(t *testing.T) { + t.Parallel() + var w bytes.Buffer mw := multipart.NewWriter(&w) for i := 0; i < 10; i++ { @@ -773,6 +1296,8 @@ } func TestResponseReadLimitBody(t *testing.T) { + t.Parallel() + // response with content-length testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 10) testResponseReadLimitBodySuccess(t, "HTTP/1.1 200 OK\r\nContent-Type: aa\r\nContent-Length: 10\r\n\r\n9876543210", 100) @@ -790,6 +1315,8 @@ } func TestRequestReadLimitBody(t *testing.T) { + t.Parallel() + // request with content-length testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 9) testRequestReadLimitBodySuccess(t, "POST /foo HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 9\r\nContent-Type: aaa\r\n\r\n123456789", 92) @@ -846,10 +1373,12 @@ } func TestRequestString(t *testing.T) { + t.Parallel() + var r Request r.SetRequestURI("http://foobar.com/aaa") s := r.String() - expectedS := "GET /aaa HTTP/1.1\r\nUser-Agent: fasthttp\r\nHost: foobar.com\r\n\r\n" + expectedS := "GET /aaa HTTP/1.1\r\nHost: foobar.com\r\n\r\n" if s != expectedS { t.Fatalf("unexpected request: %q. 
Expecting %q", s, expectedS) } @@ -867,6 +1396,8 @@ } func TestResponseBodyWriter(t *testing.T) { + t.Parallel() + var r Response w := r.BodyWriter() for i := 0; i < 10; i++ { @@ -878,6 +1409,8 @@ } func TestRequestWriteRequestURINoHost(t *testing.T) { + t.Parallel() + var req Request req.Header.SetRequestURI("http://google.com/foo/bar?baz=aaa") var w bytes.Buffer @@ -912,18 +1445,24 @@ } func TestSetRequestBodyStreamFixedSize(t *testing.T) { + t.Parallel() + testSetRequestBodyStream(t, "a", false) testSetRequestBodyStream(t, string(createFixedBody(4097)), false) testSetRequestBodyStream(t, string(createFixedBody(100500)), false) } func TestSetResponseBodyStreamFixedSize(t *testing.T) { + t.Parallel() + testSetResponseBodyStream(t, "a", false) testSetResponseBodyStream(t, string(createFixedBody(4097)), false) testSetResponseBodyStream(t, string(createFixedBody(100500)), false) } func TestSetRequestBodyStreamChunked(t *testing.T) { + t.Parallel() + testSetRequestBodyStream(t, "", true) body := "foobar baz aaa bbb ccc" @@ -934,6 +1473,8 @@ } func TestSetResponseBodyStreamChunked(t *testing.T) { + t.Parallel() + testSetResponseBodyStream(t, "", true) body := "foobar baz aaa bbb ccc" @@ -946,7 +1487,7 @@ func testSetRequestBodyStream(t *testing.T, body string, chunked bool) { var req Request req.Header.SetHost("foobar.com") - req.Header.SetMethod("POST") + req.Header.SetMethod(MethodPost) bodySize := len(body) if chunked { @@ -1013,6 +1554,8 @@ } func TestRound2(t *testing.T) { + t.Parallel() + testRound2(t, 0, 0) testRound2(t, 1, 1) testRound2(t, 2, 2) @@ -1032,6 +1575,8 @@ } func TestRequestReadChunked(t *testing.T) { + t.Parallel() + var req Request s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3\r\nabc\r\n5\r\n12345\r\n0\r\n\r\ntrail" @@ -1049,7 +1594,28 @@ verifyTrailer(t, rb, "trail") } +// See: https://github.com/erikdubbelboer/fasthttp/issues/34 +func TestRequestChunkedWhitespace(t *testing.T) { + t.Parallel() + + var req Request + + s := "POST /foo HTTP/1.1\r\nHost: google.com\r\nTransfer-Encoding: chunked\r\nContent-Type: aa/bb\r\n\r\n3 \r\nabc\r\n0\r\n\r\n" + r := bytes.NewBufferString(s) + rb := bufio.NewReader(r) + err := req.Read(rb) + if err != nil { + t.Fatalf("Unexpected error when reading chunked request: %s", err) + } + expectedBody := "abc" + if string(req.Body()) != expectedBody { + t.Fatalf("Unexpected body %q. 
Expected %q", req.Body(), expectedBody) + } +} + func TestResponseReadWithoutBody(t *testing.T) { + t.Parallel() + var resp Response testResponseReadWithoutBody(t, &resp, "HTTP/1.1 304 Not Modified\r\nContent-Type: aa\r\nContent-Length: 1235\r\n\r\nfoobar", false, @@ -1091,26 +1657,33 @@ } func TestRequestSuccess(t *testing.T) { + t.Parallel() + // empty method, user-agent and body - testRequestSuccess(t, "", "/foo/bar", "google.com", "", "", "GET") + testRequestSuccess(t, "", "/foo/bar", "google.com", "", "", MethodGet) // non-empty user-agent - testRequestSuccess(t, "GET", "/foo/bar", "google.com", "MSIE", "", "GET") + testRequestSuccess(t, MethodGet, "/foo/bar", "google.com", "MSIE", "", MethodGet) // non-empty method - testRequestSuccess(t, "HEAD", "/aaa", "fobar", "", "", "HEAD") + testRequestSuccess(t, MethodHead, "/aaa", "fobar", "", "", MethodHead) // POST method with body - testRequestSuccess(t, "POST", "/bbb", "aaa.com", "Chrome aaa", "post body", "POST") + testRequestSuccess(t, MethodPost, "/bbb", "aaa.com", "Chrome aaa", "post body", MethodPost) // PUT method with body - testRequestSuccess(t, "PUT", "/aa/bb", "a.com", "ome aaa", "put body", "PUT") + testRequestSuccess(t, MethodPut, "/aa/bb", "a.com", "ome aaa", "put body", MethodPut) // only host is set - testRequestSuccess(t, "", "", "gooble.com", "", "", "GET") + testRequestSuccess(t, "", "", "gooble.com", "", "", MethodGet) + + // get with body + testRequestSuccess(t, MethodGet, "/foo/bar", "aaa.com", "", "foobar", MethodGet) } func TestResponseSuccess(t *testing.T) { + t.Parallel() + // 200 response testResponseSuccess(t, 200, "test/plain", "server", "foobar", 200, "test/plain", "server") @@ -1121,7 +1694,7 @@ // response with missing server testResponseSuccess(t, 500, "aaa", "", "aaadfsd", - 500, "aaa", string(defaultServerName)) + 500, "aaa", "") // empty body testResponseSuccess(t, 200, "bbb", "qwer", "", @@ -1161,11 +1734,11 @@ if resp1.Header.ContentLength() != len(body) { t.Fatalf("Unexpected content-length: %d. Expected %d", resp1.Header.ContentLength(), len(body)) } - if string(resp1.Header.Peek("Content-Type")) != expectedContentType { - t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Peek("Content-Type"), expectedContentType) + if string(resp1.Header.Peek(HeaderContentType)) != expectedContentType { + t.Fatalf("Unexpected content-type: %q. Expected %q", resp1.Header.Peek(HeaderContentType), expectedContentType) } - if string(resp1.Header.Peek("Server")) != expectedServerName { - t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Peek("Server"), expectedServerName) + if string(resp1.Header.Peek(HeaderServer)) != expectedServerName { + t.Fatalf("Unexpected server: %q. Expected %q", resp1.Header.Peek(HeaderServer), expectedServerName) } if !bytes.Equal(resp1.Body(), []byte(body)) { t.Fatalf("Unexpected body: %q. 
Expected %q", resp1.Body(), body) @@ -1173,11 +1746,10 @@ } func TestRequestWriteError(t *testing.T) { + t.Parallel() + // no host testRequestWriteError(t, "", "/foo/bar", "", "", "") - - // get with body - testRequestWriteError(t, "GET", "/foo/bar", "aaa.com", "", "foobar") } func testRequestWriteError(t *testing.T, method, requestURI, host, userAgent, body string) { @@ -1185,11 +1757,11 @@ req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) - req.Header.Set("Host", host) - req.Header.Set("User-Agent", userAgent) + req.Header.Set(HeaderHost, host) + req.Header.Set(HeaderUserAgent, userAgent) req.SetBody([]byte(body)) - w := &ByteBuffer{} + w := &bytebufferpool.ByteBuffer{} bw := bufio.NewWriter(w) err := req.Write(bw) if err == nil { @@ -1202,13 +1774,13 @@ req.Header.SetMethod(method) req.Header.SetRequestURI(requestURI) - req.Header.Set("Host", host) - req.Header.Set("User-Agent", userAgent) + req.Header.Set(HeaderHost, host) + req.Header.Set(HeaderUserAgent, userAgent) req.SetBody([]byte(body)) contentType := "foobar" - if method == "POST" { - req.Header.Set("Content-Type", contentType) + if method == MethodPost { + req.Header.Set(HeaderContentType, contentType) } w := &bytes.Buffer{} @@ -1235,25 +1807,24 @@ if string(req1.Header.RequestURI()) != requestURI { t.Fatalf("Unexpected RequestURI: %q. Expected %q", req1.Header.RequestURI(), requestURI) } - if string(req1.Header.Peek("Host")) != host { - t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Peek("Host"), host) - } - if len(userAgent) == 0 { - userAgent = string(defaultUserAgent) + if string(req1.Header.Peek(HeaderHost)) != host { + t.Fatalf("Unexpected host: %q. Expected %q", req1.Header.Peek(HeaderHost), host) } - if string(req1.Header.Peek("User-Agent")) != userAgent { - t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Peek("User-Agent"), userAgent) + if string(req1.Header.Peek(HeaderUserAgent)) != userAgent { + t.Fatalf("Unexpected user-agent: %q. Expected %q", req1.Header.Peek(HeaderUserAgent), userAgent) } if !bytes.Equal(req1.Body(), []byte(body)) { t.Fatalf("Unexpected body: %q. Expected %q", req1.Body(), body) } - if method == "POST" && string(req1.Header.Peek("Content-Type")) != contentType { - t.Fatalf("Unexpected content-type: %q. Expected %q", req1.Header.Peek("Content-Type"), contentType) + if method == MethodPost && string(req1.Header.Peek(HeaderContentType)) != contentType { + t.Fatalf("Unexpected content-type: %q. 
Expected %q", req1.Header.Peek(HeaderContentType), contentType) } } func TestResponseReadSuccess(t *testing.T) { + t.Parallel() + resp := &Response{} // usual response @@ -1295,6 +1866,8 @@ } func TestResponseReadError(t *testing.T) { + t.Parallel() + resp := &Response{} // empty response @@ -1340,44 +1913,70 @@ } func TestReadBodyFixedSize(t *testing.T) { - var b []byte + t.Parallel() // zero-size body - testReadBodyFixedSize(t, b, 0) + testReadBodyFixedSize(t, 0) // small-size body - testReadBodyFixedSize(t, b, 3) + testReadBodyFixedSize(t, 3) // medium-size body - testReadBodyFixedSize(t, b, 1024) + testReadBodyFixedSize(t, 1024) // large-size body - testReadBodyFixedSize(t, b, 1024*1024) + testReadBodyFixedSize(t, 1024*1024) // smaller body after big one - testReadBodyFixedSize(t, b, 34345) + testReadBodyFixedSize(t, 34345) } func TestReadBodyChunked(t *testing.T) { - var b []byte + t.Parallel() // zero-size body - testReadBodyChunked(t, b, 0) + testReadBodyChunked(t, 0) // small-size body - testReadBodyChunked(t, b, 5) + testReadBodyChunked(t, 5) // medium-size body - testReadBodyChunked(t, b, 43488) + testReadBodyChunked(t, 43488) // big body - testReadBodyChunked(t, b, 3*1024*1024) + testReadBodyChunked(t, 3*1024*1024) // smaler body after big one - testReadBodyChunked(t, b, 12343) + testReadBodyChunked(t, 12343) +} + +func TestRequestURITLS(t *testing.T) { + t.Parallel() + + uriNoScheme := "//foobar.com/baz/aa?bb=dd&dd#sdf" + requestURI := "http:" + uriNoScheme + requestURITLS := "https:" + uriNoScheme + + var req Request + + req.isTLS = true + req.SetRequestURI(requestURI) + uri := req.URI().String() + if uri != requestURITLS { + t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURITLS) + } + + req.Reset() + req.SetRequestURI(requestURI) + uri = req.URI().String() + if uri != requestURI { + t.Fatalf("unexpected request uri: %q. Expecting %q", uri, requestURI) + } } func TestRequestURI(t *testing.T) { + t.Parallel() + host := "foobar.com" requestURI := "/aaa/bb+b%20d?ccc=ddd&qqq#1334dfds&=d" expectedPathOriginal := "/aaa/bb+b%20d" @@ -1386,7 +1985,7 @@ expectedHash := "1334dfds&=d" var req Request - req.Header.Set("Host", host) + req.Header.Set(HeaderHost, host) req.Header.SetRequestURI(requestURI) uri := req.URI() @@ -1408,6 +2007,8 @@ } func TestRequestPostArgsSuccess(t *testing.T) { + t.Parallel() + var req Request testRequestPostArgsSuccess(t, &req, "POST / HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: application/x-www-form-urlencoded\r\nContent-Length: 0\r\n\r\n", 0, "foo=", "=") @@ -1416,6 +2017,8 @@ } func TestRequestPostArgsError(t *testing.T) { + t.Parallel() + var req Request // non-post @@ -1461,7 +2064,7 @@ } } -func testReadBodyChunked(t *testing.T, b []byte, bodySize int) { +func testReadBodyChunked(t *testing.T, bodySize int) { body := createFixedBody(bodySize) chunkedBody := createChunkedBody(body) expectedTrailer := []byte("chunked shit") @@ -1479,7 +2082,7 @@ verifyTrailer(t, br, string(expectedTrailer)) } -func testReadBodyFixedSize(t *testing.T, b []byte, bodySize int) { +func testReadBodyFixedSize(t *testing.T, bodySize int) { body := createFixedBody(bodySize) expectedTrailer := []byte("traler aaaa") bodyWithTrailer := append(body, expectedTrailer...) @@ -1519,3 +2122,389 @@ } return append(b, []byte("0\r\n\r\n")...) 
} + +func TestWriteMultipartForm(t *testing.T) { + t.Parallel() + + var w bytes.Buffer + s := strings.Replace(`--foo +Content-Disposition: form-data; name="key" + +value +--foo +Content-Disposition: form-data; name="file"; filename="test.json" +Content-Type: application/json + +{"foo": "bar"} +--foo-- +`, "\n", "\r\n", -1) + mr := multipart.NewReader(strings.NewReader(s), "foo") + form, err := mr.ReadForm(1024) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if err := WriteMultipartForm(&w, form, "foo"); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if w.String() != s { + t.Fatalf("unexpected output %q", w.Bytes()) + } +} + +func TestResponseRawBodySet(t *testing.T) { + t.Parallel() + + var resp Response + + expectedS := "test" + body := []byte(expectedS) + resp.SetBodyRaw(body) + + testBodyWriteTo(t, &resp, expectedS, true) +} + +func TestRequestRawBodySet(t *testing.T) { + t.Parallel() + + var r Request + + expectedS := "test" + body := []byte(expectedS) + r.SetBodyRaw(body) + + testBodyWriteTo(t, &r, expectedS, true) +} + +func TestResponseRawBodyReset(t *testing.T) { + t.Parallel() + + var resp Response + + body := []byte("test") + resp.SetBodyRaw(body) + resp.ResetBody() + + testBodyWriteTo(t, &resp, "", true) +} + +func TestRequestRawBodyReset(t *testing.T) { + t.Parallel() + + var r Request + + body := []byte("test") + r.SetBodyRaw(body) + r.ResetBody() + + testBodyWriteTo(t, &r, "", true) +} + +func TestResponseRawBodyCopyTo(t *testing.T) { + t.Parallel() + + var resp Response + + expectedS := "test" + body := []byte(expectedS) + resp.SetBodyRaw(body) + + testResponseCopyTo(t, &resp) +} + +func TestRequestRawBodyCopyTo(t *testing.T) { + t.Parallel() + + var a Request + + body := []byte("test") + a.SetBodyRaw(body) + + var b Request + + a.CopyTo(&b) + + testBodyWriteTo(t, &a, "test", true) + testBodyWriteTo(t, &b, "test", true) +} + +type testReader struct { + read chan (int) + cb chan (struct{}) + onClose func() error +} + +func (r *testReader) Read(b []byte) (int, error) { + read := <-r.read + + if read == -1 { + return 0, io.EOF + } + + r.cb <- struct{}{} + + for i := 0; i < read; i++ { + b[i] = 'x' + } + + return read, nil +} + +func (r *testReader) Close() error { + if r.onClose != nil { + return r.onClose() + } + return nil +} + +func TestResponseImmediateHeaderFlushRegressionFixedLength(t *testing.T) { + t.Parallel() + + var r Response + + expectedS := "aaabbbccc" + buf := bytes.NewBufferString(expectedS) + r.SetBodyStream(buf, len(expectedS)) + r.ImmediateHeaderFlush = true + + testBodyWriteTo(t, &r, expectedS, false) +} + +func TestResponseImmediateHeaderFlushRegressionChunked(t *testing.T) { + t.Parallel() + + var r Response + + expectedS := "aaabbbccc" + buf := bytes.NewBufferString(expectedS) + r.SetBodyStream(buf, -1) + r.ImmediateHeaderFlush = true + + testBodyWriteTo(t, &r, expectedS, false) +} + +func TestResponseImmediateHeaderFlushFixedLength(t *testing.T) { + t.Parallel() + + var r Response + + r.ImmediateHeaderFlush = true + + ch := make(chan int) + cb := make(chan struct{}) + + buf := &testReader{read: ch, cb: cb} + + r.SetBodyStream(buf, 3) + + b := []byte{} + w := bytes.NewBuffer(b) + bb := bufio.NewWriter(w) + + bw := &r + + waitForIt := make(chan struct{}) + + go func() { + if err := bw.Write(bb); err != nil { + t.Errorf("unexpected error: %s", err) + } + waitForIt <- struct{}{} + }() + + ch <- 3 + + if !strings.Contains(w.String(), "Content-Length: 3") { + t.Fatalf("Expected headers to be flushed") + } + + if 
strings.Contains(w.String(), "xxx") { + t.Fatalf("Did not expext body to be written yet") + } + + <-cb + ch <- -1 + + <-waitForIt +} + +func TestResponseImmediateHeaderFlushFixedLengthSkipBody(t *testing.T) { + t.Parallel() + + var r Response + + r.ImmediateHeaderFlush = true + r.SkipBody = true + + ch := make(chan int) + cb := make(chan struct{}) + + buf := &testReader{read: ch, cb: cb} + + r.SetBodyStream(buf, 0) + + b := []byte{} + w := bytes.NewBuffer(b) + bb := bufio.NewWriter(w) + + var headersOnClose string + buf.onClose = func() error { + headersOnClose = w.String() + return nil + } + + bw := &r + + if err := bw.Write(bb); err != nil { + t.Errorf("unexpected error: %s", err) + } + + if !strings.Contains(headersOnClose, "Content-Length: 0") { + t.Fatalf("Expected headers to be eagerly flushed") + } +} + +func TestResponseImmediateHeaderFlushChunked(t *testing.T) { + t.Parallel() + + var r Response + + r.ImmediateHeaderFlush = true + + ch := make(chan int) + cb := make(chan struct{}) + + buf := &testReader{read: ch, cb: cb} + + r.SetBodyStream(buf, -1) + + b := []byte{} + w := bytes.NewBuffer(b) + bb := bufio.NewWriter(w) + + bw := &r + + waitForIt := make(chan struct{}) + + go func() { + if err := bw.Write(bb); err != nil { + t.Errorf("unexpected error: %s", err) + } + + waitForIt <- struct{}{} + }() + + ch <- 3 + + if !strings.Contains(w.String(), "Transfer-Encoding: chunked") { + t.Fatalf("Expected headers to be flushed") + } + + if strings.Contains(w.String(), "xxx") { + t.Fatalf("Did not expext body to be written yet") + } + + <-cb + ch <- -1 + + <-waitForIt +} + +func TestResponseImmediateHeaderFlushChunkedNoBody(t *testing.T) { + t.Parallel() + + var r Response + + r.ImmediateHeaderFlush = true + r.SkipBody = true + + ch := make(chan int) + cb := make(chan struct{}) + + buf := &testReader{read: ch, cb: cb} + + r.SetBodyStream(buf, -1) + + b := []byte{} + w := bytes.NewBuffer(b) + bb := bufio.NewWriter(w) + + var headersOnClose string + buf.onClose = func() error { + headersOnClose = w.String() + return nil + } + + bw := &r + + if err := bw.Write(bb); err != nil { + t.Errorf("unexpected error: %s", err) + } + + if !strings.Contains(headersOnClose, "Transfer-Encoding: chunked") { + t.Fatalf("Expected headers to be eagerly flushed") + } +} + +type ErroneousBodyStream struct { + errOnRead bool + errOnClose bool +} + +func (ebs *ErroneousBodyStream) Read(p []byte) (n int, err error) { + if ebs.errOnRead { + panic("reading erroneous body stream") + } + return 0, io.EOF +} + +func (ebs *ErroneousBodyStream) Close() error { + if ebs.errOnClose { + panic("closing erroneous body stream") + } + return nil +} + +func TestResponseBodyStreamErrorOnPanicDuringRead(t *testing.T) { + t.Parallel() + var resp Response + var w bytes.Buffer + bw := bufio.NewWriter(&w) + + ebs := &ErroneousBodyStream{errOnRead: true, errOnClose: false} + resp.SetBodyStream(ebs, 42) + err := resp.Write(bw) + if err == nil { + t.Fatalf("expected error when writing response.") + } + e, ok := err.(*ErrBodyStreamWritePanic) + if !ok { + t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e) + } + if e.Error() != "panic while writing body stream: reading erroneous body stream" { + t.Fatalf("unexpected error value, got: %+v.", e.Error()) + } +} + +func TestResponseBodyStreamErrorOnPanicDuringClose(t *testing.T) { + t.Parallel() + var resp Response + var w bytes.Buffer + bw := bufio.NewWriter(&w) + + ebs := &ErroneousBodyStream{errOnRead: false, errOnClose: true} + resp.SetBodyStream(ebs, 42) + err 
:= resp.Write(bw) + if err == nil { + t.Fatalf("expected error when writing response.") + } + e, ok := err.(*ErrBodyStreamWritePanic) + if !ok { + t.Fatalf("expected error struct to be *ErrBodyStreamWritePanic, got: %+v.", e) + } + if e.Error() != "panic while writing body stream: closing erroneous body stream" { + t.Fatalf("unexpected error value, got: %+v.", e.Error()) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/lbclient_example_test.go golang-github-valyala-fasthttp-1.31.0/lbclient_example_test.go --- golang-github-valyala-fasthttp-20160617/lbclient_example_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/lbclient_example_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,42 @@ +package fasthttp_test + +import ( + "fmt" + "log" + + "github.com/valyala/fasthttp" +) + +func ExampleLBClient() { + // Requests will be spread among these servers. + servers := []string{ + "google.com:80", + "foobar.com:8080", + "127.0.0.1:123", + } + + // Prepare clients for each server + var lbc fasthttp.LBClient + for _, addr := range servers { + c := &fasthttp.HostClient{ + Addr: addr, + } + lbc.Clients = append(lbc.Clients, c) + } + + // Send requests to load-balanced servers + var req fasthttp.Request + var resp fasthttp.Response + for i := 0; i < 10; i++ { + url := fmt.Sprintf("http://abcedfg/foo/bar/%d", i) + req.SetRequestURI(url) + if err := lbc.Do(&req, &resp); err != nil { + log.Fatalf("Error when sending request: %s", err) + } + if resp.StatusCode() != fasthttp.StatusOK { + log.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), fasthttp.StatusOK) + } + + useResponseBody(resp.Body()) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/lbclient.go golang-github-valyala-fasthttp-1.31.0/lbclient.go --- golang-github-valyala-fasthttp-20160617/lbclient.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/lbclient.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,165 @@ +package fasthttp + +import ( + "sync" + "sync/atomic" + "time" +) + +// BalancingClient is the interface for clients, which may be passed +// to LBClient.Clients. +type BalancingClient interface { + DoDeadline(req *Request, resp *Response, deadline time.Time) error + PendingRequests() int +} + +// LBClient balances requests among available LBClient.Clients. +// +// It has the following features: +// +// - Balances load among available clients using 'least loaded' + 'least total' +// hybrid technique. +// - Dynamically decreases load on unhealthy clients. +// +// It is forbidden copying LBClient instances. Create new instances instead. +// +// It is safe calling LBClient methods from concurrently running goroutines. +type LBClient struct { + noCopy noCopy //nolint:unused,structcheck + + // Clients must contain non-zero clients list. + // Incoming requests are balanced among these clients. + Clients []BalancingClient + + // HealthCheck is a callback called after each request. + // + // The request, response and the error returned by the client + // is passed to HealthCheck, so the callback may determine whether + // the client is healthy. + // + // Load on the current client is decreased if HealthCheck returns false. + // + // By default HealthCheck returns false if err != nil. + HealthCheck func(req *Request, resp *Response, err error) bool + + // Timeout is the request timeout used when calling LBClient.Do. + // + // DefaultLBClientTimeout is used by default. 
+ Timeout time.Duration + + cs []*lbClient + + once sync.Once +} + +// DefaultLBClientTimeout is the default request timeout used by LBClient +// when calling LBClient.Do. +// +// The timeout may be overridden via LBClient.Timeout. +const DefaultLBClientTimeout = time.Second + +// DoDeadline calls DoDeadline on the least loaded client +func (cc *LBClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { + return cc.get().DoDeadline(req, resp, deadline) +} + +// DoTimeout calculates deadline and calls DoDeadline on the least loaded client +func (cc *LBClient) DoTimeout(req *Request, resp *Response, timeout time.Duration) error { + deadline := time.Now().Add(timeout) + return cc.get().DoDeadline(req, resp, deadline) +} + +// Do calls calculates deadline using LBClient.Timeout and calls DoDeadline +// on the least loaded client. +func (cc *LBClient) Do(req *Request, resp *Response) error { + timeout := cc.Timeout + if timeout <= 0 { + timeout = DefaultLBClientTimeout + } + return cc.DoTimeout(req, resp, timeout) +} + +func (cc *LBClient) init() { + if len(cc.Clients) == 0 { + panic("BUG: LBClient.Clients cannot be empty") + } + for _, c := range cc.Clients { + cc.cs = append(cc.cs, &lbClient{ + c: c, + healthCheck: cc.HealthCheck, + }) + } +} + +func (cc *LBClient) get() *lbClient { + cc.once.Do(cc.init) + + cs := cc.cs + + minC := cs[0] + minN := minC.PendingRequests() + minT := atomic.LoadUint64(&minC.total) + for _, c := range cs[1:] { + n := c.PendingRequests() + t := atomic.LoadUint64(&c.total) + if n < minN || (n == minN && t < minT) { + minC = c + minN = n + minT = t + } + } + return minC +} + +type lbClient struct { + c BalancingClient + healthCheck func(req *Request, resp *Response, err error) bool + penalty uint32 + + // total amount of requests handled. + total uint64 +} + +func (c *lbClient) DoDeadline(req *Request, resp *Response, deadline time.Time) error { + err := c.c.DoDeadline(req, resp, deadline) + if !c.isHealthy(req, resp, err) && c.incPenalty() { + // Penalize the client returning error, so the next requests + // are routed to another clients. 
+ time.AfterFunc(penaltyDuration, c.decPenalty) + } else { + atomic.AddUint64(&c.total, 1) + } + return err +} + +func (c *lbClient) PendingRequests() int { + n := c.c.PendingRequests() + m := atomic.LoadUint32(&c.penalty) + return n + int(m) +} + +func (c *lbClient) isHealthy(req *Request, resp *Response, err error) bool { + if c.healthCheck == nil { + return err == nil + } + return c.healthCheck(req, resp, err) +} + +func (c *lbClient) incPenalty() bool { + m := atomic.AddUint32(&c.penalty, 1) + if m > maxPenalty { + c.decPenalty() + return false + } + return true +} + +func (c *lbClient) decPenalty() { + atomic.AddUint32(&c.penalty, ^uint32(0)) +} + +const ( + maxPenalty = 300 + + penaltyDuration = 3 * time.Second +) diff -Nru golang-github-valyala-fasthttp-20160617/LICENSE golang-github-valyala-fasthttp-1.31.0/LICENSE --- golang-github-valyala-fasthttp-20160617/LICENSE 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/LICENSE 2021-10-09 18:39:05.000000000 +0000 @@ -1,22 +1,9 @@ The MIT License (MIT) -Copyright (c) 2015-2016 Aliaksandr Valialkin, VertaMedia +Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
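
The LBClient introduced above routes each request to the backend with the fewest pending requests (ties broken by the lowest total served), and the HealthCheck callback drives the penalty mechanism: a false result increments the client's penalty so it is de-prioritized for penaltyDuration before recovering. A minimal usage sketch with a custom health check follows; the backend addresses and the 5xx threshold are illustrative assumptions, not part of this patch:

```go
package main

import (
	"log"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	lbc := &fasthttp.LBClient{
		// Backend addresses are placeholders for this sketch.
		Clients: []fasthttp.BalancingClient{
			&fasthttp.HostClient{Addr: "10.0.0.1:80"},
			&fasthttp.HostClient{Addr: "10.0.0.2:80"},
		},
		// Treat transport errors and 5xx responses as unhealthy so the
		// failing backend is penalized and temporarily de-prioritized.
		HealthCheck: func(req *fasthttp.Request, resp *fasthttp.Response, err error) bool {
			return err == nil && resp.StatusCode() < fasthttp.StatusInternalServerError
		},
		// Overrides DefaultLBClientTimeout (1s) for LBClient.Do.
		Timeout: 500 * time.Millisecond,
	}

	req := fasthttp.AcquireRequest()
	resp := fasthttp.AcquireResponse()
	defer fasthttp.ReleaseRequest(req)
	defer fasthttp.ReleaseResponse(resp)

	req.SetRequestURI("http://example.local/some/path")
	if err := lbc.Do(req, resp); err != nil {
		log.Fatalf("request failed: %s", err)
	}
	log.Printf("status: %d", resp.StatusCode())
}
```
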
diff -Nru golang-github-valyala-fasthttp-20160617/methods.go golang-github-valyala-fasthttp-1.31.0/methods.go --- golang-github-valyala-fasthttp-20160617/methods.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/methods.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,14 @@ +package fasthttp + +// HTTP methods were copied from net/http. +const ( + MethodGet = "GET" // RFC 7231, 4.3.1 + MethodHead = "HEAD" // RFC 7231, 4.3.2 + MethodPost = "POST" // RFC 7231, 4.3.3 + MethodPut = "PUT" // RFC 7231, 4.3.4 + MethodPatch = "PATCH" // RFC 5789 + MethodDelete = "DELETE" // RFC 7231, 4.3.5 + MethodConnect = "CONNECT" // RFC 7231, 4.3.6 + MethodOptions = "OPTIONS" // RFC 7231, 4.3.7 + MethodTrace = "TRACE" // RFC 7231, 4.3.8 +) diff -Nru golang-github-valyala-fasthttp-20160617/nocopy.go golang-github-valyala-fasthttp-1.31.0/nocopy.go --- golang-github-valyala-fasthttp-20160617/nocopy.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/nocopy.go 2021-10-09 18:39:05.000000000 +0000 @@ -4,6 +4,8 @@ // so `go vet` gives a warning if this struct is copied. // // See https://github.com/golang/go/issues/8005#issuecomment-190753527 for details. -type noCopy struct{} +// and also: https://stackoverflow.com/questions/52494458/nocopy-minimal-example +type noCopy struct{} //nolint:unused -func (*noCopy) Lock() {} +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} diff -Nru golang-github-valyala-fasthttp-20160617/peripconn_test.go golang-github-valyala-fasthttp-1.31.0/peripconn_test.go --- golang-github-valyala-fasthttp-20160617/peripconn_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/peripconn_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -5,6 +5,8 @@ ) func TestIPxUint32(t *testing.T) { + t.Parallel() + testIPxUint32(t, 0) testIPxUint32(t, 10) testIPxUint32(t, 0x12892392) @@ -19,6 +21,8 @@ } func TestPerIPConnCounter(t *testing.T) { + t.Parallel() + var cc perIPConnCounter expectPanic(t, func() { cc.Unregister(123) }) diff -Nru golang-github-valyala-fasthttp-20160617/pprofhandler/pprof.go golang-github-valyala-fasthttp-1.31.0/pprofhandler/pprof.go --- golang-github-valyala-fasthttp-20160617/pprofhandler/pprof.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/pprofhandler/pprof.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,44 @@ +package pprofhandler + +import ( + "net/http/pprof" + rtp "runtime/pprof" + "strings" + + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttpadaptor" +) + +var ( + cmdline = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Cmdline) + profile = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Profile) + symbol = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Symbol) + trace = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Trace) + index = fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Index) +) + +// PprofHandler serves server runtime profiling data in the format expected by the pprof visualization tool. +// +// See https://golang.org/pkg/net/http/pprof/ for details. 
+func PprofHandler(ctx *fasthttp.RequestCtx) { + ctx.Response.Header.Set("Content-Type", "text/html") + if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/cmdline") { + cmdline(ctx) + } else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/profile") { + profile(ctx) + } else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/symbol") { + symbol(ctx) + } else if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/trace") { + trace(ctx) + } else { + for _, v := range rtp.Profiles() { + ppName := v.Name() + if strings.HasPrefix(string(ctx.Path()), "/debug/pprof/"+ppName) { + namedHandler := fasthttpadaptor.NewFastHTTPHandlerFunc(pprof.Handler(ppName).ServeHTTP) + namedHandler(ctx) + return + } + } + index(ctx) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/prefork/prefork.go golang-github-valyala-fasthttp-1.31.0/prefork/prefork.go --- golang-github-valyala-fasthttp-20160617/prefork/prefork.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/prefork/prefork.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,283 @@ +package prefork + +import ( + "errors" + "flag" + "log" + "net" + "os" + "os/exec" + "runtime" + + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/reuseport" +) + +const ( + preforkChildFlag = "-prefork-child" + defaultNetwork = "tcp4" +) + +var ( + defaultLogger = Logger(log.New(os.Stderr, "", log.LstdFlags)) + // ErrOverRecovery is returned when the times of starting over child prefork processes exceed + // the threshold. + ErrOverRecovery = errors.New("exceeding the value of RecoverThreshold") + + // ErrOnlyReuseportOnWindows is returned when Reuseport is false. + ErrOnlyReuseportOnWindows = errors.New("windows only supports Reuseport = true") +) + +// Logger is used for logging formatted messages. +type Logger interface { + // Printf must have the same semantics as log.Printf. + Printf(format string, args ...interface{}) +} + +// Prefork implements fasthttp server prefork +// +// Preforks master process (with all cores) between several child processes +// increases performance significantly, because Go doesn't have to share +// and manage memory between cores +// +// WARNING: using prefork prevents the use of any global state! +// Things like in-memory caches won't work. +type Prefork struct { + // The network must be "tcp", "tcp4" or "tcp6". + // + // By default is "tcp4" + Network string + + // Flag to use a listener with reuseport, if not a file Listener will be used + // See: https://www.nginx.com/blog/socket-sharding-nginx-release-1-9-1/ + // + // It's disabled by default + Reuseport bool + + // Child prefork processes may exit with failure and will be started over until the times reach + // the value of RecoverThreshold, then it will return and terminate the server. + RecoverThreshold int + + // By default standard logger from log package is used. 
+ Logger Logger + + ServeFunc func(ln net.Listener) error + ServeTLSFunc func(ln net.Listener, certFile, keyFile string) error + ServeTLSEmbedFunc func(ln net.Listener, certData, keyData []byte) error + + ln net.Listener + files []*os.File +} + +func init() { //nolint:gochecknoinits + // Definition flag to not break the program when the user adds their own flags + // and runs `flag.Parse()` + flag.Bool(preforkChildFlag[1:], false, "Is a child process") +} + +// IsChild checks if the current thread/process is a child +func IsChild() bool { + for _, arg := range os.Args[1:] { + if arg == preforkChildFlag { + return true + } + } + + return false +} + +// New wraps the fasthttp server to run with preforked processes +func New(s *fasthttp.Server) *Prefork { + return &Prefork{ + Network: defaultNetwork, + RecoverThreshold: runtime.GOMAXPROCS(0) / 2, + Logger: s.Logger, + ServeFunc: s.Serve, + ServeTLSFunc: s.ServeTLS, + ServeTLSEmbedFunc: s.ServeTLSEmbed, + } +} + +func (p *Prefork) logger() Logger { + if p.Logger != nil { + return p.Logger + } + return defaultLogger +} + +func (p *Prefork) listen(addr string) (net.Listener, error) { + runtime.GOMAXPROCS(1) + + if p.Network == "" { + p.Network = defaultNetwork + } + + if p.Reuseport { + return reuseport.Listen(p.Network, addr) + } + + return net.FileListener(os.NewFile(3, "")) +} + +func (p *Prefork) setTCPListenerFiles(addr string) error { + if p.Network == "" { + p.Network = defaultNetwork + } + + tcpAddr, err := net.ResolveTCPAddr(p.Network, addr) + if err != nil { + return err + } + + tcplistener, err := net.ListenTCP(p.Network, tcpAddr) + if err != nil { + return err + } + + p.ln = tcplistener + + fl, err := tcplistener.File() + if err != nil { + return err + } + + p.files = []*os.File{fl} + + return nil +} + +func (p *Prefork) doCommand() (*exec.Cmd, error) { + /* #nosec G204 */ + cmd := exec.Command(os.Args[0], append(os.Args[1:], preforkChildFlag)...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.ExtraFiles = p.files + return cmd, cmd.Start() +} + +func (p *Prefork) prefork(addr string) (err error) { + if !p.Reuseport { + if runtime.GOOS == "windows" { + return ErrOnlyReuseportOnWindows + } + + if err = p.setTCPListenerFiles(addr); err != nil { + return + } + + // defer for closing the net.Listener opened by setTCPListenerFiles. 
+ defer func() { + e := p.ln.Close() + if err == nil { + err = e + } + }() + } + + type procSig struct { + pid int + err error + } + + goMaxProcs := runtime.GOMAXPROCS(0) + sigCh := make(chan procSig, goMaxProcs) + childProcs := make(map[int]*exec.Cmd) + + defer func() { + for _, proc := range childProcs { + _ = proc.Process.Kill() + } + }() + + for i := 0; i < goMaxProcs; i++ { + var cmd *exec.Cmd + if cmd, err = p.doCommand(); err != nil { + p.logger().Printf("failed to start a child prefork process, error: %v\n", err) + return + } + + childProcs[cmd.Process.Pid] = cmd + go func() { + sigCh <- procSig{cmd.Process.Pid, cmd.Wait()} + }() + } + + var exitedProcs int + for sig := range sigCh { + delete(childProcs, sig.pid) + + p.logger().Printf("one of the child prefork processes exited with "+ + "error: %v", sig.err) + + if exitedProcs++; exitedProcs > p.RecoverThreshold { + p.logger().Printf("child prefork processes exit too many times, "+ + "which exceeds the value of RecoverThreshold(%d), "+ + "exiting the master process.\n", exitedProcs) + err = ErrOverRecovery + break + } + + var cmd *exec.Cmd + if cmd, err = p.doCommand(); err != nil { + break + } + childProcs[cmd.Process.Pid] = cmd + go func() { + sigCh <- procSig{cmd.Process.Pid, cmd.Wait()} + }() + } + + return +} + +// ListenAndServe serves HTTP requests from the given TCP addr +func (p *Prefork) ListenAndServe(addr string) error { + if IsChild() { + ln, err := p.listen(addr) + if err != nil { + return err + } + + p.ln = ln + + return p.ServeFunc(ln) + } + + return p.prefork(addr) +} + +// ListenAndServeTLS serves HTTPS requests from the given TCP addr +// +// certFile and keyFile are paths to TLS certificate and key files. +func (p *Prefork) ListenAndServeTLS(addr, certKey, certFile string) error { + if IsChild() { + ln, err := p.listen(addr) + if err != nil { + return err + } + + p.ln = ln + + return p.ServeTLSFunc(ln, certFile, certKey) + } + + return p.prefork(addr) +} + +// ListenAndServeTLSEmbed serves HTTPS requests from the given TCP addr +// +// certData and keyData must contain valid TLS certificate and key data. +func (p *Prefork) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { + if IsChild() { + ln, err := p.listen(addr) + if err != nil { + return err + } + + p.ln = ln + + return p.ServeTLSEmbedFunc(ln, certData, keyData) + } + + return p.prefork(addr) +} diff -Nru golang-github-valyala-fasthttp-20160617/prefork/prefork_test.go golang-github-valyala-fasthttp-1.31.0/prefork/prefork_test.go --- golang-github-valyala-fasthttp-20160617/prefork/prefork_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/prefork/prefork_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,228 @@ +package prefork + +import ( + "fmt" + "math/rand" + "net" + "os" + "reflect" + "runtime" + "testing" + + "github.com/valyala/fasthttp" +) + +func setUp() { + os.Args = append(os.Args, preforkChildFlag) +} + +func tearDown() { + os.Args = os.Args[:len(os.Args)-1] +} + +func getAddr() string { + return fmt.Sprintf("0.0.0.0:%d", rand.Intn(9000-3000)+3000) +} + +func Test_IsChild(t *testing.T) { + // This test can't run parallel as it modifies os.Args. 
+ + v := IsChild() + if v { + t.Errorf("IsChild() == %v, want %v", v, false) + } + + setUp() + defer tearDown() + + v = IsChild() + if !v { + t.Errorf("IsChild() == %v, want %v", v, true) + } +} + +func Test_New(t *testing.T) { + t.Parallel() + + s := &fasthttp.Server{} + p := New(s) + + if p.Network != defaultNetwork { + t.Errorf("Prefork.Netork == %s, want %s", p.Network, defaultNetwork) + } + + if reflect.ValueOf(p.ServeFunc).Pointer() != reflect.ValueOf(s.Serve).Pointer() { + t.Errorf("Prefork.ServeFunc == %p, want %p", p.ServeFunc, s.Serve) + } + + if reflect.ValueOf(p.ServeTLSFunc).Pointer() != reflect.ValueOf(s.ServeTLS).Pointer() { + t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSFunc, s.ServeTLS) + } + + if reflect.ValueOf(p.ServeTLSEmbedFunc).Pointer() != reflect.ValueOf(s.ServeTLSEmbed).Pointer() { + t.Errorf("Prefork.ServeTLSFunc == %p, want %p", p.ServeTLSEmbedFunc, s.ServeTLSEmbed) + } +} + +func Test_listen(t *testing.T) { + t.Parallel() + + p := &Prefork{ + Reuseport: true, + } + addr := getAddr() + + ln, err := p.listen(addr) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + ln.Close() + + lnAddr := ln.Addr().String() + if lnAddr != addr { + t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr) + } + + if p.Network != defaultNetwork { + t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork) + } + + procs := runtime.GOMAXPROCS(0) + if procs != 1 { + t.Errorf("GOMAXPROCS == %d, want %d", procs, 1) + } +} + +func Test_setTCPListenerFiles(t *testing.T) { + t.Parallel() + + if runtime.GOOS == "windows" { + t.SkipNow() + } + + p := &Prefork{} + addr := getAddr() + + err := p.setTCPListenerFiles(addr) + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if p.ln == nil { + t.Fatal("Prefork.ln is nil") + } + + p.ln.Close() + + lnAddr := p.ln.Addr().String() + if lnAddr != addr { + t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr) + } + + if p.Network != defaultNetwork { + t.Errorf("Prefork.Network == %s, want %s", p.Network, defaultNetwork) + } + + if len(p.files) != 1 { + t.Errorf("Prefork.files == %d, want %d", len(p.files), 1) + } +} + +func Test_ListenAndServe(t *testing.T) { + // This test can't run parallel as it modifies os.Args. + + setUp() + defer tearDown() + + s := &fasthttp.Server{} + p := New(s) + p.Reuseport = true + p.ServeFunc = func(ln net.Listener) error { + return nil + } + + addr := getAddr() + + err := p.ListenAndServe(addr) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + p.ln.Close() + + lnAddr := p.ln.Addr().String() + if lnAddr != addr { + t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr) + } + + if p.ln == nil { + t.Error("Prefork.ln is nil") + } +} + +func Test_ListenAndServeTLS(t *testing.T) { + // This test can't run parallel as it modifies os.Args. + + setUp() + defer tearDown() + + s := &fasthttp.Server{} + p := New(s) + p.Reuseport = true + p.ServeTLSFunc = func(ln net.Listener, certFile, keyFile string) error { + return nil + } + + addr := getAddr() + + err := p.ListenAndServeTLS(addr, "./key", "./cert") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + p.ln.Close() + + lnAddr := p.ln.Addr().String() + if lnAddr != addr { + t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr) + } + + if p.ln == nil { + t.Error("Prefork.ln is nil") + } +} + +func Test_ListenAndServeTLSEmbed(t *testing.T) { + // This test can't run parallel as it modifies os.Args. 
+ + setUp() + defer tearDown() + + s := &fasthttp.Server{} + p := New(s) + p.Reuseport = true + p.ServeTLSEmbedFunc = func(ln net.Listener, certData, keyData []byte) error { + return nil + } + + addr := getAddr() + + err := p.ListenAndServeTLSEmbed(addr, []byte("key"), []byte("cert")) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + p.ln.Close() + + lnAddr := p.ln.Addr().String() + if lnAddr != addr { + t.Errorf("Prefork.Addr == %s, want %s", lnAddr, addr) + } + + if p.ln == nil { + t.Error("Prefork.ln is nil") + } +} diff -Nru golang-github-valyala-fasthttp-20160617/prefork/README.md golang-github-valyala-fasthttp-1.31.0/prefork/README.md --- golang-github-valyala-fasthttp-20160617/prefork/README.md 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/prefork/README.md 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,90 @@ +# Prefork + +Server prefork implementation. + +Preforks master process between several child processes increases performance, because Go doesn't have to share and manage memory between cores. + +**WARNING: using prefork prevents the use of any global state!. Things like in-memory caches won't work.** + +- How it works: + +```go +import ( + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/prefork" +) + +server := &fasthttp.Server{ + // Your configuration +} + +// Wraps the server with prefork +preforkServer := prefork.New(server) + +if err := preforkServer.ListenAndServe(":8080"); err != nil { + panic(err) +} +``` + +## Benchmarks + +Environment: + +- Machine: MacBook Pro 13-inch, 2017 +- OS: MacOS 10.15.3 +- Go: go1.13.6 darwin/amd64 + +Handler code: + +```go +func requestHandler(ctx *fasthttp.RequestCtx) { + // Simulates some hard work + time.Sleep(100 * time.Millisecond) +} +``` + +Test command: + +```bash +$ wrk -H 'Host: localhost' -H 'Accept: text/plain,text/html;q=0.9,application/xhtml+xml;q=0.9,application/xml;q=0.8,*/*;q=0.7' -H 'Connection: keep-alive' --latency -d 15 -c 512 --timeout 8 -t 4 http://localhost:8080 +``` + +Results: + +- prefork + +```bash +Running 15s test @ http://localhost:8080 + 4 threads and 512 connections + Thread Stats Avg Stdev Max +/- Stdev + Latency 4.75ms 4.27ms 126.24ms 97.45% + Req/Sec 26.46k 4.16k 71.18k 88.72% + Latency Distribution + 50% 4.55ms + 75% 4.82ms + 90% 5.46ms + 99% 15.49ms + 1581916 requests in 15.09s, 140.30MB read + Socket errors: connect 0, read 318, write 0, timeout 0 +Requests/sec: 104861.58 +Transfer/sec: 9.30MB +``` + +- **non**-prefork + +```bash +Running 15s test @ http://localhost:8080 + 4 threads and 512 connections + Thread Stats Avg Stdev Max +/- Stdev + Latency 6.42ms 11.83ms 177.19ms 96.42% + Req/Sec 24.96k 5.83k 56.83k 82.93% + Latency Distribution + 50% 4.53ms + 75% 4.93ms + 90% 6.94ms + 99% 74.54ms + 1472441 requests in 15.09s, 130.59MB read + Socket errors: connect 0, read 265, write 0, timeout 0 +Requests/sec: 97553.34 +Transfer/sec: 8.65MB +``` diff -Nru golang-github-valyala-fasthttp-20160617/README.md golang-github-valyala-fasthttp-1.31.0/README.md --- golang-github-valyala-fasthttp-20160617/README.md 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/README.md 2021-10-09 18:39:05.000000000 +0000 @@ -1,15 +1,20 @@ -[![Build Status](https://travis-ci.org/valyala/fasthttp.svg)](https://travis-ci.org/valyala/fasthttp) -[![GoDoc](https://godoc.org/github.com/valyala/fasthttp?status.svg)](http://godoc.org/github.com/valyala/fasthttp) -[![Go 
Report](http://goreportcard.com/badge/valyala/fasthttp)](http://goreportcard.com/report/valyala/fasthttp) +# fasthttp [![GoDoc](https://godoc.org/github.com/valyala/fasthttp?status.svg)](http://godoc.org/github.com/valyala/fasthttp) [![Go Report](https://goreportcard.com/badge/github.com/valyala/fasthttp)](https://goreportcard.com/report/github.com/valyala/fasthttp) + +![FastHTTP – Fastest and reliable HTTP implementation in Go](https://github.com/fasthttp/docs-assets/raw/master/banner@0.5.png) -# fasthttp Fast HTTP implementation for Go. +# fasthttp might not be for you! +fasthttp was design for some high performance edge cases. **Unless** your server/client needs to handle **thousands of small to medium requests per seconds** and needs a consistent low millisecond response time fasthttp might not be for you. **For most cases `net/http` is much better** as it's easier to use and can handle more cases. For most cases you won't even notice the performance difference. + + +## General info and links + Currently fasthttp is successfully used by [VertaMedia](https://vertamedia.com/) in a production serving up to 200K rps from more than 1.5M concurrent keep-alive connections per physical server. -[TechEmpower Benchmark round 12 results](https://www.techempower.com/benchmarks/#section=data-r12&hw=peak&test=plaintext) +[TechEmpower Benchmark round 19 results](https://www.techempower.com/benchmarks/#section=data-r19&hw=ph&test=plaintext) [Server Benchmarks](#http-server-performance-comparison-with-nethttp) @@ -23,6 +28,8 @@ [Code examples](examples) +[Awesome fasthttp tools](https://github.com/fasthttp) + [Switching from net/http to fasthttp](#switching-from-nethttp-to-fasthttp) [Fasthttp best practices](#fasthttp-best-practices) @@ -33,7 +40,7 @@ [FAQ](#faq) -# HTTP server performance comparison with [net/http](https://golang.org/pkg/net/http/) +## HTTP server performance comparison with [net/http](https://golang.org/pkg/net/http/) In short, fasthttp server is up to 10 times faster than net/http. Below are benchmark results. @@ -94,7 +101,7 @@ BenchmarkServerGet100ReqPerConn10KClients-4 50000000 282 ns/op 0 B/op 0 allocs/op ``` -# HTTP client comparison with net/http +## HTTP client comparison with net/http In short, fasthttp client is up to 10 times faster than net/http. Below are benchmark results. @@ -156,20 +163,20 @@ ``` -# Install +## Install ``` go get -u github.com/valyala/fasthttp ``` -# Switching from net/http to fasthttp +## Switching from net/http to fasthttp Unfortunately, fasthttp doesn't provide API identical to net/http. See the [FAQ](#faq) for details. There is [net/http -> fasthttp handler converter](https://godoc.org/github.com/valyala/fasthttp/fasthttpadaptor), -but it is advisable writing fasthttp request handlers by hands for gaining -all the fasthttp advantages (especially high performance :) ). +but it is better to write fasthttp request handlers by hand in order to use +all of the fasthttp advantages (especially high performance :) ). Important points: @@ -239,7 +246,7 @@ ``` * Fasthttp allows setting response headers and writing response body -in arbitrary order. There is no 'headers first, then body' restriction +in an arbitrary order. There is no 'headers first, then body' restriction like in net/http. 
The following code is valid for fasthttp: ```go @@ -273,12 +280,14 @@ * Fasthttp doesn't provide [ServeMux](https://golang.org/pkg/net/http/#ServeMux), but there are more powerful third-party routers and web frameworks -with fasthttp support exist: +with fasthttp support: - * [Iris](https://github.com/kataras/iris) * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) - * [fasthttprouter](https://github.com/buaazp/fasthttprouter) - * [echo v2](https://github.com/labstack/echo) + * [router](https://github.com/fasthttp/router) + * [lu](https://github.com/vincentLiuxiang/lu) + * [atreugo](https://github.com/savsgio/atreugo) + * [Fiber](https://github.com/gofiber/fiber) + * [Gearbox](https://github.com/gogearbox/gearbox) Net/http code with simple ServeMux is trivially converted to fasthttp code: @@ -308,7 +317,7 @@ } } - fastttp.ListenAndServe(":80", m) + fasthttp.ListenAndServe(":80", m) ``` * net/http -> fasthttp conversion table: @@ -372,7 +381,7 @@ See [the example](https://godoc.org/github.com/valyala/fasthttp#example-RequestCtx-TimeoutError) for more details. -Use brilliant tool - [race detector](http://blog.golang.org/race-detector) - +Use this brilliant tool - [race detector](http://blog.golang.org/race-detector) - for detecting and eliminating data races in your program. If you detected data race related to fasthttp in your program, then there is high probability you forgot calling [TimeoutError](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.TimeoutError) @@ -390,17 +399,17 @@ [expvarhandler](https://godoc.org/github.com/valyala/fasthttp/expvarhandler). -# Performance optimization tips for multi-core systems +## Performance optimization tips for multi-core systems * Use [reuseport](https://godoc.org/github.com/valyala/fasthttp/reuseport) listener. * Run a separate server instance per CPU core with GOMAXPROCS=1. * Pin each server instance to a separate CPU core using [taskset](http://linux.die.net/man/1/taskset). * Ensure the interrupts of multiqueue network card are evenly distributed between CPU cores. See [this article](https://blog.cloudflare.com/how-to-achieve-low-latency/) for details. -* Use Go 1.6 as it provides some considerable performance improvements. +* Use the latest version of Go as each version contains performance improvements. -# Fasthttp best practices +## Fasthttp best practices * Do not allocate objects and `[]byte` buffers - just reuse them as much as possible. Fasthttp API design encourages this. @@ -421,7 +430,7 @@ [html/template](https://golang.org/pkg/html/template/) in your webserver. -# Tricks with `[]byte` buffers +## Tricks with `[]byte` buffers The following tricks are used by fasthttp. Use them in your code too. @@ -476,36 +485,46 @@ uintBuf := fasthttp.AppendUint(nil, 1234) ``` -# Related projects +## Related projects - * [fasthttp-contrib](https://github.com/fasthttp-contrib) - various useful + * [fasthttp](https://github.com/fasthttp) - various useful helpers for projects based on fasthttp. - * [iris](https://github.com/kataras/iris) - web application framework built - on top of fasthttp. Features speed and functionality. * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) - fast and powerful routing package for fasthttp servers. - * [fasthttprouter](https://github.com/buaazp/fasthttprouter) - a high + * [http2](https://github.com/dgrr/http2) - HTTP/2 implementation for fasthttp. + * [router](https://github.com/fasthttp/router) - a high performance fasthttp request router that scales well. 
- * [echo](https://github.com/labstack/echo) - fast and unfancy HTTP server - framework with fasthttp support. - * [websocket](https://github.com/leavengood/websocket) - Gorilla-based + * [fastws](https://github.com/fasthttp/fastws) - Bloatless WebSocket package made for fasthttp + to handle Read/Write operations concurrently. + * [gramework](https://github.com/gramework/gramework) - a web framework made by one of fasthttp maintainers + * [lu](https://github.com/vincentLiuxiang/lu) - a high performance + go middleware web framework which is based on fasthttp. + * [websocket](https://github.com/fasthttp/websocket) - Gorilla-based websocket implementation for fasthttp. + * [websocket](https://github.com/dgrr/websocket) - Event-based high-performance WebSocket library for zero-allocation + websocket servers and clients. + * [fasthttpsession](https://github.com/phachon/fasthttpsession) - a fast and powerful session package for fasthttp servers. + * [atreugo](https://github.com/savsgio/atreugo) - High performance and extensible micro web framework with zero memory allocations in hot paths. + * [kratgo](https://github.com/savsgio/kratgo) - Simple, lightweight and ultra-fast HTTP Cache to speed up your websites. + * [kit-plugins](https://github.com/wencan/kit-plugins/tree/master/transport/fasthttp) - go-kit transport implementation for fasthttp. + * [Fiber](https://github.com/gofiber/fiber) - An Expressjs inspired web framework running on Fasthttp + * [Gearbox](https://github.com/gogearbox/gearbox) - :gear: gearbox is a web framework written in Go with a focus on high performance and memory optimization -# FAQ +## FAQ * *Why creating yet another http package instead of optimizing net/http?* Because net/http API limits many optimization opportunities. For example: * net/http Request object lifetime isn't limited by request handler execution - time. So the server must create new request object per each request instead - of reusing existing objects like fasthttp do. + time. So the server must create a new request object per each request instead + of reusing existing objects like fasthttp does. * net/http headers are stored in a `map[string][]string`. So the server must parse all the headers, convert them from `[]byte` to `string` and put them into the map before calling user-provided request handler. This all requires unnecessary memory allocations avoided by fasthttp. - * net/http client API requires creating new response object per each request. + * net/http client API requires creating a new response object per each request. * *Why fasthttp API is incompatible with net/http?* @@ -519,10 +538,9 @@ * *Why fasthttp doesn't support HTTP/2.0 and WebSockets?* - There are [plans](TODO) for adding HTTP/2.0 and WebSockets support - in the future. - In the mean time, third parties may use [RequestCtx.Hijack](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack) - for implementing these goodies. See [the first third-party websocket implementation on the top of fasthttp](https://github.com/leavengood/websocket). + [HTTP/2.0 support](https://github.com/fasthttp/http2) is in progress. [WebSockets](https://github.com/fasthttp/websockets) has been done already. + Third parties also may use [RequestCtx.Hijack](https://godoc.org/github.com/valyala/fasthttp#RequestCtx.Hijack) + for implementing these goodies. * *Are there known net/http advantages comparing to fasthttp?* @@ -530,9 +548,10 @@ * net/http supports [HTTP/2.0 starting from go1.6](https://http2.golang.org/). 
* net/http API is stable, while fasthttp API constantly evolves. * net/http handles more HTTP corner cases. + * net/http can stream both request and response bodies + * net/http can handle bigger bodies as it doesn't read the whole body into memory * net/http should contain less bugs, since it is used and tested by much wider audience. - * net/http works on Go older than 1.5. * *Why fasthttp API prefers returning `[]byte` instead of `string`?* @@ -543,10 +562,9 @@ * *Which GO versions are supported by fasthttp?* - Go1.5+. Older versions won't be supported, since their standard package - [miss useful functions](https://github.com/valyala/fasthttp/issues/5). + Go 1.15.x. Older versions won't be supported. -* *Please provide real benchmark data and sever information* +* *Please provide real benchmark data and server information* See [this issue](https://github.com/valyala/fasthttp/issues/4). @@ -555,10 +573,13 @@ There are no plans to add request routing into fasthttp. Use third-party routers and web frameworks with fasthttp support: - * [Iris](https://github.com/kataras/iris) * [fasthttp-routing](https://github.com/qiangxue/fasthttp-routing) - * [fasthttprouter](https://github.com/buaazp/fasthttprouter) - * [echo v2](https://github.com/labstack/echo) + * [router](https://github.com/fasthttp/router) + * [gramework](https://github.com/gramework/gramework) + * [lu](https://github.com/vincentLiuxiang/lu) + * [atreugo](https://github.com/savsgio/atreugo) + * [Fiber](https://github.com/gofiber/fiber) + * [Gearbox](https://github.com/gogearbox/gearbox) See also [this issue](https://github.com/valyala/fasthttp/issues/9) for more info. diff -Nru golang-github-valyala-fasthttp-20160617/reuseport/reuseport_bsd.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_bsd.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport_bsd.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_bsd.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,9 +0,0 @@ -// +build darwin dragonfly freebsd netbsd openbsd rumprun - -package reuseport - -import ( - "syscall" -) - -const soReusePort = syscall.SO_REUSEPORT diff -Nru golang-github-valyala-fasthttp-20160617/reuseport/reuseport_error.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_error.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport_error.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_error.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,15 @@ +package reuseport + +import ( + "fmt" +) + +// ErrNoReusePort is returned if the OS doesn't support SO_REUSEPORT. +type ErrNoReusePort struct { + err error +} + +// Error implements error interface. +func (e *ErrNoReusePort) Error() string { + return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %s", e.err) +} diff -Nru golang-github-valyala-fasthttp-20160617/reuseport/reuseport.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,4 +1,5 @@ -// +build linux darwin dragonfly freebsd netbsd openbsd rumprun +//go:build !windows +// +build !windows // Package reuseport provides TCP net.Listener with SO_REUSEPORT support. 
// @@ -9,98 +10,38 @@ package reuseport import ( - "errors" - "fmt" "net" - "os" - "syscall" -) - -func getSockaddr(network, addr string) (sa syscall.Sockaddr, soType int, err error) { - // TODO: add support for tcp and tcp6 networks. + "strings" - if network != "tcp4" { - return nil, -1, errors.New("only tcp4 network is supported") - } - - tcpAddr, err := net.ResolveTCPAddr(network, addr) - if err != nil { - return nil, -1, err - } - - var sa4 syscall.SockaddrInet4 - sa4.Port = tcpAddr.Port - copy(sa4.Addr[:], tcpAddr.IP.To4()) - return &sa4, syscall.AF_INET, nil -} - -// ErrNoReusePort is returned if the OS doesn't support SO_REUSEPORT. -type ErrNoReusePort struct { - err error -} - -// Error implements error interface. -func (e *ErrNoReusePort) Error() string { - return fmt.Sprintf("The OS doesn't support SO_REUSEPORT: %s", e.err) -} + "github.com/valyala/tcplisten" +) // Listen returns TCP listener with SO_REUSEPORT option set. // -// Only tcp4 network is supported. +// The returned listener tries enabling the following TCP options, which usually +// have positive impact on performance: +// +// - TCP_DEFER_ACCEPT. This option expects that the server reads from accepted +// connections before writing to them. +// +// - TCP_FASTOPEN. See https://lwn.net/Articles/508865/ for details. +// +// Use https://github.com/valyala/tcplisten if you want customizing +// these options. +// +// Only tcp4 and tcp6 networks are supported. // // ErrNoReusePort error is returned if the system doesn't support SO_REUSEPORT. -func Listen(network, addr string) (l net.Listener, err error) { - var ( - soType, fd int - file *os.File - sockaddr syscall.Sockaddr - ) - - if sockaddr, soType, err = getSockaddr(network, addr); err != nil { - return nil, err - } - - syscall.ForkLock.RLock() - fd, err = syscall.Socket(soType, syscall.SOCK_STREAM, syscall.IPPROTO_TCP) - if err == nil { - syscall.CloseOnExec(fd) - } - syscall.ForkLock.RUnlock() - if err != nil { - return nil, err - } - - if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil { - syscall.Close(fd) - return nil, err - } - - if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, soReusePort, 1); err != nil { - syscall.Close(fd) +func Listen(network, addr string) (net.Listener, error) { + ln, err := cfg.NewListener(network, addr) + if err != nil && strings.Contains(err.Error(), "SO_REUSEPORT") { return nil, &ErrNoReusePort{err} } + return ln, err +} - if err = syscall.Bind(fd, sockaddr); err != nil { - syscall.Close(fd) - return nil, err - } - - if err = syscall.Listen(fd, syscall.SOMAXCONN); err != nil { - syscall.Close(fd) - return nil, err - } - - name := fmt.Sprintf("reuseport.%d.%s.%s", os.Getpid(), network, addr) - file = os.NewFile(uintptr(fd), name) - if l, err = net.FileListener(file); err != nil { - file.Close() - return nil, err - } - - if err = file.Close(); err != nil { - l.Close() - return nil, err - } - - return l, err +var cfg = &tcplisten.Config{ + ReusePort: true, + DeferAccept: true, + FastOpen: true, } diff -Nru golang-github-valyala-fasthttp-20160617/reuseport/reuseport_linux.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_linux.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport_linux.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_linux.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,5 +0,0 @@ -// +build linux - -package reuseport - -const soReusePort = 0x0F diff -Nru 
golang-github-valyala-fasthttp-20160617/reuseport/reuseport_test.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_test.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -8,16 +8,40 @@ "time" ) -func TestNewListener(t *testing.T) { - addr := "localhost:10081" - serversCount := 20 - requestsCount := 1000 +func TestTCP4(t *testing.T) { + t.Parallel() + testNewListener(t, "tcp4", "localhost:10081", 20, 1000) +} + +func TestTCP6(t *testing.T) { + t.Parallel() + + // Run this test only if tcp6 interface exists. + if hasLocalIPv6(t) { + testNewListener(t, "tcp6", "[::1]:10082", 20, 1000) + } +} + +func hasLocalIPv6(t *testing.T) bool { + addrs, err := net.InterfaceAddrs() + if err != nil { + t.Fatalf("cannot obtain local interfaces: %s", err) + } + for _, a := range addrs { + if a.String() == "::1/128" { + return true + } + } + return false +} + +func testNewListener(t *testing.T, network, addr string, serversCount, requestsCount int) { var lns []net.Listener doneCh := make(chan struct{}, serversCount) for i := 0; i < serversCount; i++ { - ln, err := Listen("tcp4", addr) + ln, err := Listen(network, addr) if err != nil { t.Fatalf("cannot create listener %d: %s", i, err) } @@ -29,7 +53,7 @@ } for i := 0; i < requestsCount; i++ { - c, err := net.Dial("tcp4", addr) + c, err := net.Dial(network, addr) if err != nil { t.Fatalf("%d. unexpected error when dialing: %s", i, err) } @@ -45,14 +69,14 @@ ch := make(chan struct{}) go func() { if resp, err = ioutil.ReadAll(c); err != nil { - t.Fatalf("%d. unexpected error when reading response: %s", i, err) + t.Errorf("%d. unexpected error when reading response: %s", i, err) } close(ch) }() select { case <-ch: - case <-time.After(200 * time.Millisecond): - t.Fatalf("%d. timeout when waiting for response: %s", i, err) + case <-time.After(250 * time.Millisecond): + t.Fatalf("%d. timeout when waiting for response", i) } if string(resp) != req { @@ -86,7 +110,7 @@ } req, err := ioutil.ReadAll(c) if err != nil { - t.Fatalf("unepxected error when reading request: %s", err) + t.Fatalf("unexpected error when reading request: %s", err) } if _, err = c.Write(req); err != nil { t.Fatalf("unexpected error when writing response: %s", err) diff -Nru golang-github-valyala-fasthttp-20160617/reuseport/reuseport_windows.go golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_windows.go --- golang-github-valyala-fasthttp-20160617/reuseport/reuseport_windows.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/reuseport/reuseport_windows.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,23 @@ +package reuseport + +import ( + "context" + "net" + "syscall" + + "golang.org/x/sys/windows" +) + +var listenConfig = net.ListenConfig{ + Control: func(network, address string, c syscall.RawConn) (err error) { + return c.Control(func(fd uintptr) { + err = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_REUSEADDR, 1) + }) + }, +} + +// Listen returns TCP listener with SO_REUSEADDR option set, SO_REUSEPORT is not supported on Windows, so it uses +// SO_REUSEADDR as an alternative to achieve the same effect. 
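A usage sketch for the rewritten reuseport package (the address and response text are illustrative): several processes, or one per-core server instance, can bind the same address, and on Windows the SO_REUSEADDR fallback described above is used.

```go
package main

import (
	"log"

	"github.com/valyala/fasthttp"
	"github.com/valyala/fasthttp/reuseport"
)

func main() {
	// Each process (or per-core instance) may call Listen on the same address.
	ln, err := reuseport.Listen("tcp4", "127.0.0.1:8080")
	if err != nil {
		log.Fatalf("cannot create SO_REUSEPORT listener: %s", err)
	}
	defer ln.Close()

	handler := func(ctx *fasthttp.RequestCtx) {
		ctx.SetBodyString("served via reuseport listener")
	}
	if err := fasthttp.Serve(ln, handler); err != nil {
		log.Fatalf("error serving: %s", err)
	}
}
```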
+func Listen(network, addr string) (net.Listener, error) { + return listenConfig.Listen(context.Background(), network, addr) +} diff -Nru golang-github-valyala-fasthttp-20160617/SECURITY.md golang-github-valyala-fasthttp-1.31.0/SECURITY.md --- golang-github-valyala-fasthttp-20160617/SECURITY.md 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/SECURITY.md 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,115 @@ +### TL;DR + +We use a simplified version of [Golang Security Policy](https://golang.org/security). +For example, for now we skip CVE assignment. + +### Reporting a Security Bug + +Please report to us any issues you find. This document explains how to do that and what to expect in return. + +All security bugs in our releases should be reported by email to oss-security@highload.solutions. +This mail is delivered to a small security team. +Your email will be acknowledged within 24 hours, and you'll receive a more detailed response +to your email within 72 hours indicating the next steps in handling your report. +For critical problems, you can encrypt your report using our PGP key (listed below). + +Please use a descriptive subject line for your report email. +After the initial reply to your report, the security team will +endeavor to keep you informed of the progress being made towards a fix and full announcement. +These updates will be sent at least every five days. +In reality, this is more likely to be every 24-48 hours. + +If you have not received a reply to your email within 48 hours or you have not heard from the security +team for the past five days please contact us by email to developers@highload.solutions or by Telegram message +to [our support](https://t.me/highload_support). +Please note that developers@highload.solutions list includes all developers, who may be outside our opensource security team. +When escalating on this list, please do not disclose the details of the issue. +Simply state that you're trying to reach a member of the security team. + +### Flagging Existing Issues as Security-related + +If you believe that an existing issue is security-related, we ask that you send an email to oss-security@highload.solutions. +The email should include the issue ID and a short description of why it should be handled according to this security policy. + +### Disclosure Process + +Our project uses the following disclosure process: + +- Once the security report is received it is assigned a primary handler. This person coordinates the fix and release process. +- The issue is confirmed and a list of affected software is determined. +- Code is audited to find any potential similar problems. +- Fixes are prepared for the two most recent major releases and the head/master revision. These fixes are not yet committed to the public repository. +- To notify users, a new issue without security details is submitted to our GitHub repository. +- Three working days following this notification, the fixes are applied to the public repository and a new release is issued. +- On the date that the fixes are applied, announcement is published in the issue. + +This process can take some time, especially when coordination is required with maintainers of other projects. +Every effort will be made to handle the bug in as timely a manner as possible, however it's important that we follow +the process described above to ensure that disclosures are handled consistently. 
+ +### Receiving Security Updates +The best way to receive security announcements is to subscribe ("Watch") to our repository. +Any GitHub issues pertaining to a security issue will be prefixed with [security]. + +### Comments on This Policy +If you have any suggestions to improve this policy, please send an email to oss-security@highload.solutions for discussion. + +### PGP Key for oss-security@highload.ltd + +We accept PGP-encrypted email, but the majority of the security team are not regular PGP users +so it's somewhat inconvenient. Please only use PGP for critical security reports. + +``` +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQINBFzdjYUBEACa3YN+QVSlnXofUjxr+YrmIaF+da0IUq+TRM4aqUXALsemEdGh +yIl7Z6qOOy1d2kPe6t//H9l/92lJ1X7i6aEBK4n/pnPZkwbpy9gGpebgvTZFvcbe +mFhF6k1FM35D8TxneJSjizPyGhJPqcr5qccqf8R64TlQx5Ud1JqT2l8P1C5N7gNS +lEYXq1h4zBCvTWk1wdeLRRPx7Bn6xrgmyu/k61dLoJDvpvWNATVFDA67oTrPgzTW +xtLbbk/xm0mK4a8zMzIpNyz1WkaJW9+4HFXaL+yKlsx7iHe2O7VlGoqS0kdeQup4 +1HIw/P7yc0jBlNMLUzpuA6ElYUwESWsnCI71YY1x4rKgI+GqH1mWwgn7tteuXQtb +Zj0vEdjK3IKIOSbzbzAvSbDt8F1+o7EMtdy1eUysjKSQgFkDlT6JRmYvEup5/IoG +iknh/InQq9RmGFKii6pXWWoltC0ebfCwYOXvymyDdr/hYDqJeHS9Tenpy86Doaaf +HGf5nIFAMB2G5ctNpBwzNXR2MAWkeHQgdr5a1xmog0hS125usjnUTet3QeCyo4kd +gVouoOroMcqFFUXdYaMH4c3KWz0afhTmIaAsFFOv/eMdadVA4QyExTJf3TAoQ+kH +lKDlbOAIxEZWRPDFxMRixaVPQC+VxhBcaQ+yNoaUkM0V2m8u8sDBpzi1OQARAQAB +tDxPU1MgU2VjdXJpdHksIEhpZ2hsb2FkIExURCA8b3NzLXNlY3VyaXR5QGhpZ2hs +b2FkLnNvbHV0aW9ucz6JAlQEEwEIAD4WIQRljYp380uKq2g8TeqsQcvu+Qp2TAUC +XN2NhQIbAwUJB4YfgAULCQgHAgYVCgkICwIEFgIDAQIeAQIXgAAKCRCsQcvu+Qp2 +TKmED/96YoQoOjD28blFFrigvAsiNcNNZoX9I0dX1lNpD83fBJf+/9i+x4jqUnI5 +5XK/DFTDbhpw8kQBpxS9eEuIYnuo0RdLLp1ctNWTlpwfyHn92mGddl/uBdYHUuUk +cjhIQcFaCcWRY+EpamDlv1wmZ83IwBr8Hu5FS+/Msyw1TBvtTRVKW1KoGYMYoXLk +BzIglRPwn821B6s4BvK/RJnZkrmHMBZBfYMf+iSMSYd2yPmfT8wbcAjgjLfQa28U +gbt4u9xslgKjuM83IqwFfEXBnm7su3OouGWqc+62mQTsbnK65zRFnx6GXRXC1BAi +6m9Tm1PU0IiINz66ainquspkXYeHjd9hTwfR3BdFnzBTRRM01cKMFabWbLj8j0p8 +fF4g9cxEdiLrzEF7Yz4WY0mI4Cpw4eJZfsHMc07Jn7QxfJhIoq+rqBOtEmTjnxMh +aWeykoXMHlZN4K0ZrAytozVH1D4bugWA9Zuzi9U3F9hrVVABm11yyhd2iSqI6/FR +GcCFOCBW1kEJbzoEguub+BV8LDi8ldljHalvur5k/VFhoDBxniYNsKmiCLVCmDWs +/nF84hCReAOJt0vDGwqHe3E2BFFPbKwdJLRNkjxBY0c/pvaV+JxbWQmaxDZNeIFV +hFcVGp48HNY3qLWZdsQIfT9m1masJFLVuq8Wx7bYs8Et5eFnH7kCDQRc3Y2FARAA +2DJWAxABydyIdCxgFNdqnYyWS46vh2DmLmRMqgasNlD0ozG4S9bszBsgnUI2Xs06 +J76kFRh8MMHcu9I4lUKCQzfrA4uHkiOK5wvNCaWP+C6JUYNHsqPwk/ILO3gtQ/Ws +LLf/PW3rJZVOZB+WY8iaYc20l5vukTaVw4qbEi9dtLkJvVpNHt//+jayXU6s3ew1 +2X5xdwyAZxaxlnzFaY/Xo/qR+bZhVFC0T9pAECnHv9TVhFGp0JE9ipPGnro5xTIS +LttdAkzv4AuSVTIgWgTkh8nN8t7STJqfPEv0I12nmmYHMUyTYOurkfskF3jY2x6x +8l02NQ4d5KdC3ReV1j51swrGcZCwsWNp51jnEXKwo+B0NM5OmoRrNJgF2iDgLehs +hP00ljU7cB8/1/7kdHZStYaUHICFOFqHzg415FlYm+jpY0nJp/b9BAO0d0/WYnEe +Xjihw8EVBAqzEt4kay1BQonZAypeYnGBJr7vNvdiP+mnRwly5qZSGiInxGvtZZFt +zL1E3osiF+muQxFcM63BeGdJeYXy+MoczkWa4WNggfcHlGAZkMYiv28zpr4PfrK9 +mvj4Nu8s71PE9pPpBoZcNDf9v1sHuu96jDSITsPx5YMvvKZWhzJXFKzk6YgAsNH/ +MF0G+/qmKJZpCdvtHKpYM1uHX85H81CwWJFfBPthyD8AEQEAAYkCPAQYAQgAJhYh +BGWNinfzS4qraDxN6qxBy+75CnZMBQJc3Y2FAhsMBQkHhh+AAAoJEKxBy+75CnZM +Rn8P/RyL1bhU4Q4WpvmlkepCAwNA0G3QvnKcSZNHEPE5h7H3IyrA/qy16A9eOsgm +sthsHYlo5A5lRIy4wPHkFCClMrMHdKuoS72//qgw+oOrBcwb7Te+Nas+ewhaJ7N9 +vAX06vDH9bLl52CPbtats5+eBpePgP3HDPxd7CWHxq9bzJTbzqsTkN7JvoovR2dP +itPJDij7QYLYVEM1t7QxUVpVwAjDi/kCtC9ts5L+V0snF2n3bHZvu04EXdpvxOQI +pG/7Q+/WoI8NU6Bb/FA3tJGYIhSwI3SY+5XV/TAZttZaYSh2SD8vhc+eo+gW9sAN +xa+VESBQCht9+tKIwEwHs1efoRgFdbwwJ2c+33+XydQ6yjdXoX1mn2uyCr82jorZ +xTzbkY04zr7oZ+0fLpouOFg/mrSL4w2bWEhdHuyoVthLBjnRme0wXCaS3g3mYdLG 
+nSUkogOGOOvvvBtoq/vfx0Eu79piUtw5D8yQSrxLDuz8GxCrVRZ0tYIHb26aTE9G +cDsW/Lg5PjcY/LgVNEWOxDQDFVurlImnlVJFb3q+NrWvPbgeIEWwJDCay/z25SEH +k3bSOXLp8YGRnlkWUmoeL4g/CCK52iAAlfscZNoKMILhBnbCoD657jpa5GQKJj/U +Q8kjgr7kwV/RSosNV9HCPj30mVyiCQ1xg+ZLzMKXVCuBWd+G +=lnt2 +-----END PGP PUBLIC KEY BLOCK----- +``` diff -Nru golang-github-valyala-fasthttp-20160617/server.go golang-github-valyala-fasthttp-1.31.0/server.go --- golang-github-valyala-fasthttp-20160617/server.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/server.go 2021-10-09 18:39:05.000000000 +0000 @@ -2,6 +2,7 @@ import ( "bufio" + "context" "crypto/tls" "errors" "fmt" @@ -10,13 +11,20 @@ "mime/multipart" "net" "os" - "runtime/debug" "strings" "sync" "sync/atomic" "time" ) +var errNoCertOrKeyProvided = errors.New("cert or key has not provided") + +var ( + // ErrAlreadyServing is returned when calling Serve on a Server + // that is already serving connections. + ErrAlreadyServing = errors.New("Server is already serving connections") +) + // ServeConn serves HTTP requests from the given connection // using the given handler. // @@ -127,6 +135,9 @@ // must be limited. type RequestHandler func(ctx *RequestCtx) +// ServeHandler must process tls.Config.NextProto negotiated requests. +type ServeHandler func(c net.Conn) error + // Server implements HTTP server. // // Default Server settings should satisfy the majority of Server users. @@ -137,11 +148,41 @@ // // It is safe to call Server methods from concurrently running goroutines. type Server struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Handler for processing incoming requests. + // + // Take into account that no `panic` recovery is done by `fasthttp` (thus any `panic` will take down the entire server). + // Instead the user should use `recover` to handle these situations. Handler RequestHandler + // ErrorHandler for returning a response in case of an error while receiving or parsing the request. + // + // The following is a non-exhaustive list of errors that can be expected as argument: + // * io.EOF + // * io.ErrUnexpectedEOF + // * ErrGetOnly + // * ErrSmallBuffer + // * ErrBodyTooLarge + // * ErrBrokenChunks + ErrorHandler func(ctx *RequestCtx, err error) + + // HeaderReceived is called after receiving the header + // + // non zero RequestConfig field values will overwrite the default configs + HeaderReceived func(header *RequestHeader) RequestConfig + + // ContinueHandler is called after receiving the Expect 100 Continue Header + // + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.1.1 + // Using ContinueHandler a server can make decisioning on whether or not + // to read a potentially large request body based on the headers + // + // The default is to automatically read request bodies of Expect 100 Continue requests + // like they are normal requests + ContinueHandler func(header *RequestHeader) bool + // Server name for sending in response headers. // // Default server name is used if left blank. @@ -150,15 +191,10 @@ // The maximum number of concurrent connections the server may serve. // // DefaultConcurrency is used if not set. - Concurrency int - - // Whether to disable keep-alive connections. // - // The server will close all the incoming connections after sending - // the first response to client if this option is set to true. - // - // By default keep-alive connections are enabled. 
- DisableKeepalive bool + // Concurrency only works if you either call Serve once, or only ServeConn multiple times. + // It works with ListenAndServe as well. + Concurrency int // Per-connection buffer size for requests' reading. // This also limits the maximum header size. @@ -174,19 +210,26 @@ // Default buffer size is used if not set. WriteBufferSize int - // Maximum duration for reading the full request (including body). - // - // This also limits the maximum duration for idle keep-alive - // connections. + // ReadTimeout is the amount of time allowed to read + // the full request including body. The connection's read + // deadline is reset when the connection opens, or for + // keep-alive connections after the first byte has been read. // // By default request read timeout is unlimited. ReadTimeout time.Duration - // Maximum duration for writing the full response (including body). + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset after the request handler + // has returned. // // By default response write timeout is unlimited. WriteTimeout time.Duration + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alive is enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. + IdleTimeout time.Duration + // Maximum number of concurrent client connections allowed per IP. // // By default unlimited number of concurrent connections @@ -201,30 +244,43 @@ // By default unlimited number of requests may be served per connection. MaxRequestsPerConn int - // Maximum keep-alive connection lifetime. - // - // The server closes keep-alive connection after its' lifetime - // expiration. - // - // See also ReadTimeout for limiting the duration of idle keep-alive - // connections. - // - // By default keep-alive connection lifetime is unlimited. + // MaxKeepaliveDuration is a no-op and only left here for backwards compatibility. + // Deprecated: Use IdleTimeout instead. MaxKeepaliveDuration time.Duration + // Period between tcp keep-alive messages. + // + // TCP keep-alive period is determined by operation system by default. + TCPKeepalivePeriod time.Duration + // Maximum request body size. // // The server rejects requests with bodies exceeding this limit. // - // By default request body size is unlimited. + // Request body size is limited by DefaultMaxRequestBodySize by default. MaxRequestBodySize int + // Whether to disable keep-alive connections. + // + // The server will close all the incoming connections after sending + // the first response to client if this option is set to true. + // + // By default keep-alive connections are enabled. + DisableKeepalive bool + + // Whether to enable tcp keep-alive connections. + // + // Whether the operating system should send tcp keep-alive messages on the tcp connection. + // + // By default tcp keep-alive connections are disabled. + TCPKeepalive bool + // Aggressively reduces memory usage at the cost of higher CPU usage // if set to true. // // Try enabling this option only if the server consumes too much memory - // serving mostly idle keep-alive connections (more than 1M concurrent - // connections). This may reduce memory usage by up to 50%. + // serving mostly idle keep-alive connections. This may reduce memory + // usage by more than 50%. // // Aggressive memory usage reduction is disabled by default. ReduceMemoryUsage bool @@ -238,6 +294,14 @@ // Server accepts all the requests by default. 
GetOnly bool + // Will not pre parse Multipart Form data if set to true. + // + // This option is useful for servers that desire to treat + // multipart form data as a binary blob, or choose when to parse the data. + // + // Server pre parses multipart form data by default. + DisablePreParseMultipartForm bool + // Logs all errors, including the most frequent // 'connection reset by peer', 'broken pipe' and 'connection timeout' // errors. Such errors are common in production serving real-world @@ -248,6 +312,14 @@ // are suppressed in order to limit output log traffic. LogAllErrors bool + // Will not log potentially sensitive content in error logs + // + // This option is useful for servers that handle sensitive data + // in the request/response. + // + // Server logs all full errors by default. + SecureErrorLogMessage bool + // Header names are passed as-is without normalization // if this option is set. // @@ -266,11 +338,62 @@ // * cONTENT-lenGTH -> Content-Length DisableHeaderNamesNormalizing bool + // SleepWhenConcurrencyLimitsExceeded is a duration to be slept of if + // the concurrency limit in exceeded (default [when is 0]: don't sleep + // and accept new connections immediately). + SleepWhenConcurrencyLimitsExceeded time.Duration + + // NoDefaultServerHeader, when set to true, causes the default Server header + // to be excluded from the Response. + // + // The default Server header value is the value of the Name field or an + // internal default value in its absence. With this option set to true, + // the only time a Server header will be sent is if a non-zero length + // value is explicitly provided during a request. + NoDefaultServerHeader bool + + // NoDefaultDate, when set to true, causes the default Date + // header to be excluded from the Response. + // + // The default Date header value is the current date value. When + // set to true, the Date will not be present. + NoDefaultDate bool + + // NoDefaultContentType, when set to true, causes the default Content-Type + // header to be excluded from the Response. + // + // The default Content-Type header value is the internal default value. When + // set to true, the Content-Type will not be present. + NoDefaultContentType bool + + // KeepHijackedConns is an opt-in disable of connection + // close by fasthttp after connections' HijackHandler returns. + // This allows to save goroutines, e.g. when fasthttp used to upgrade + // http connections to WS and connection goes to another handler, + // which will close it when needed. + KeepHijackedConns bool + + // CloseOnShutdown when true adds a `Connection: close` header when when the server is shutting down. + CloseOnShutdown bool + + // StreamRequestBody enables request body streaming, + // and calls the handler sooner when given body is + // larger then the current limit. + StreamRequestBody bool + + // ConnState specifies an optional callback function that is + // called when a client connection changes state. See the + // ConnState type and associated constants for details. + ConnState func(net.Conn, ConnState) + // Logger, which is used by RequestCtx.Logger(). // // By default standard logger from log package is used. Logger Logger + tlsConfig *tls.Config + nextProtos map[string]ServeHandler + concurrency uint32 concurrencyCh chan struct{} perIPConnCounter perIPConnCounter @@ -280,7 +403,14 @@ readerPool sync.Pool writerPool sync.Pool hijackConnPool sync.Pool - bytePool sync.Pool + + // We need to know our listeners so we can close them in Shutdown(). 
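To tie the new Server fields documented above together, a hedged configuration sketch (the timeouts, limits, server name and error message are illustrative values, not recommendations):

```go
package main

import (
	"log"
	"net"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("ok")
		},
		// Called for errors while reading or parsing the request;
		// this response is just one possible choice.
		ErrorHandler: func(ctx *fasthttp.RequestCtx, err error) {
			ctx.Error("bad request", fasthttp.StatusBadRequest)
		},
		Name:                  "example-server",
		ReadTimeout:           5 * time.Second,
		WriteTimeout:          5 * time.Second,
		IdleTimeout:           30 * time.Second, // replaces the deprecated MaxKeepaliveDuration
		MaxRequestBodySize:    4 * 1024 * 1024,
		NoDefaultServerHeader: true,
		// Observe connection state transitions (StateNew, StateClosed, ...).
		ConnState: func(c net.Conn, state fasthttp.ConnState) {
			log.Printf("%s -> %v", c.RemoteAddr(), state)
		},
	}
	log.Fatal(s.ListenAndServe(":8080"))
}
```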
+ ln []net.Listener + + mu sync.Mutex + open int32 + stop int32 + done chan struct{} } // TimeoutHandler creates RequestHandler, which returns StatusRequestTimeout @@ -291,6 +421,17 @@ // msg to the client if there are more than Server.Concurrency concurrent // handlers h are running at the moment. func TimeoutHandler(h RequestHandler, timeout time.Duration, msg string) RequestHandler { + return TimeoutWithCodeHandler(h, timeout, msg, StatusRequestTimeout) +} + +// TimeoutWithCodeHandler creates RequestHandler, which returns an error with +// the given msg and status code to the client if h didn't return during +// the given duration. +// +// The returned handler may return StatusTooManyRequests error with the given +// msg to the client if there are more than Server.Concurrency concurrent +// handlers h are running at the moment. +func TimeoutWithCodeHandler(h RequestHandler, timeout time.Duration, msg string, statusCode int) RequestHandler { if timeout <= 0 { return h } @@ -318,12 +459,27 @@ select { case <-ch: case <-ctx.timeoutTimer.C: - ctx.TimeoutError(msg) + ctx.TimeoutErrorWithCode(msg, statusCode) } stopTimer(ctx.timeoutTimer) } } +//RequestConfig configure the per request deadline and body limits +type RequestConfig struct { + // ReadTimeout is the maximum duration for reading the entire + // request body. + // a zero value means that default values will be honored + ReadTimeout time.Duration + // WriteTimeout is the maximum duration before timing out + // writes of the response. + // a zero value means that default values will be honored + WriteTimeout time.Duration + // Maximum request body size. + // a zero value means that default values will be honored + MaxRequestBodySize int +} + // CompressHandler returns RequestHandler that transparently compresses // response body generated by h if the request contains 'gzip' or 'deflate' // 'Accept-Encoding' header. @@ -332,7 +488,7 @@ } // CompressHandlerLevel returns RequestHandler that transparently compresses -// response body generated by h if the request contains 'gzip' or 'deflate' +// response body generated by h if the request contains a 'gzip' or 'deflate' // 'Accept-Encoding' header. // // Level is the desired compression level: @@ -341,19 +497,45 @@ // * CompressBestSpeed // * CompressBestCompression // * CompressDefaultCompression +// * CompressHuffmanOnly func CompressHandlerLevel(h RequestHandler, level int) RequestHandler { return func(ctx *RequestCtx) { h(ctx) - ce := ctx.Response.Header.PeekBytes(strContentEncoding) - if len(ce) > 0 { - // Do not compress responses with non-empty - // Content-Encoding. - return - } if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { - ctx.Response.gzipBody(level) + ctx.Response.gzipBody(level) //nolint:errcheck } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) { - ctx.Response.deflateBody(level) + ctx.Response.deflateBody(level) //nolint:errcheck + } + } +} + +// CompressHandlerBrotliLevel returns RequestHandler that transparently compresses +// response body generated by h if the request contains a 'br', 'gzip' or 'deflate' +// 'Accept-Encoding' header. +// +// brotliLevel is the desired compression level for brotli. +// +// * CompressBrotliNoCompression +// * CompressBrotliBestSpeed +// * CompressBrotliBestCompression +// * CompressBrotliDefaultCompression +// +// otherLevel is the desired compression level for gzip and deflate. 
+// +// * CompressNoCompression +// * CompressBestSpeed +// * CompressBestCompression +// * CompressDefaultCompression +// * CompressHuffmanOnly +func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) RequestHandler { + return func(ctx *RequestCtx) { + h(ctx) + if ctx.Request.Header.HasAcceptEncodingBytes(strBr) { + ctx.Response.brotliBody(brotliLevel) //nolint:errcheck + } else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) { + ctx.Response.gzipBody(otherLevel) //nolint:errcheck + } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) { + ctx.Response.deflateBody(otherLevel) //nolint:errcheck } } } @@ -373,7 +555,7 @@ // running goroutines. The only exception is TimeoutError*, which may be called // while other goroutines accessing RequestCtx. type RequestCtx struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck // Incoming request. // @@ -387,11 +569,10 @@ userValues userData - lastReadDuration time.Duration - connID uint64 connRequestNum uint64 connTime time.Time + remoteAddr net.Addr time time.Time @@ -404,14 +585,19 @@ timeoutCh chan struct{} timeoutTimer *time.Timer - hijackHandler HijackHandler + hijackHandler HijackHandler + hijackNoResponse bool } // HijackHandler must process the hijacked connection c. // -// The connection c is automatically closed after returning from HijackHandler. +// If KeepHijackedConns is disabled, which is by default, +// the connection c is automatically closed after returning from HijackHandler. // -// The connection c must not be used after returning from the handler. +// The connection c must not be used after returning from the handler, if KeepHijackedConns is disabled. +// +// When KeepHijackedConns enabled, fasthttp will not Close() the connection, +// you must do it when you need it. You must not use c in any way after calling Close(). type HijackHandler func(c net.Conn) // Hijack registers the given handler for connection hijacking. @@ -427,6 +613,7 @@ // * Unexpected error during response writing to the connection. // // The server stops processing requests from hijacked connections. +// // Server limits such as Concurrency, ReadTimeout, WriteTimeout, etc. // aren't applied to hijacked connections. // @@ -442,6 +629,20 @@ ctx.hijackHandler = handler } +// HijackSetNoResponse changes the behavior of hijacking a request. +// If HijackSetNoResponse is called with false fasthttp will send a response +// to the client before calling the HijackHandler (default). If HijackSetNoResponse +// is called with true no response is send back before calling the +// HijackHandler supplied in the Hijack function. +func (ctx *RequestCtx) HijackSetNoResponse(noResponse bool) { + ctx.hijackNoResponse = noResponse +} + +// Hijacked returns true after Hijack is called. +func (ctx *RequestCtx) Hijacked() bool { + return ctx.hijackHandler != nil +} + // SetUserValue stores the given value (arbitrary object) // under the given key in ctx. // @@ -481,11 +682,57 @@ return ctx.userValues.GetBytes(key) } +// VisitUserValues calls visitor for each existing userValue. +// +// visitor must not retain references to key and value after returning. +// Make key and/or value copies if you need storing them after returning. 
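A sketch of composing the handler wrappers introduced above (the timeout, message and status code are illustrative): compression is negotiated from the Accept-Encoding header, and the timeout wrapper replies for handlers that overrun.

```go
package main

import (
	"log"
	"time"

	"github.com/valyala/fasthttp"
)

func slowHandler(ctx *fasthttp.RequestCtx) {
	// Simulate work; in real code this would be the application handler.
	time.Sleep(100 * time.Millisecond)
	ctx.SetBodyString("done")
}

func main() {
	// Compress with brotli when the client accepts it, otherwise gzip/deflate.
	h := fasthttp.CompressHandlerBrotliLevel(slowHandler,
		fasthttp.CompressBrotliDefaultCompression,
		fasthttp.CompressDefaultCompression)

	// Reply with 503 and a message if slowHandler takes longer than one second.
	h = fasthttp.TimeoutWithCodeHandler(h, time.Second,
		"request took too long", fasthttp.StatusServiceUnavailable)

	log.Fatal(fasthttp.ListenAndServe(":8080", h))
}
```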
+func (ctx *RequestCtx) VisitUserValues(visitor func([]byte, interface{})) { + for i, n := 0, len(ctx.userValues); i < n; i++ { + kv := &ctx.userValues[i] + visitor(kv.key, kv.value) + } +} + +// ResetUserValues allows to reset user values from Request Context +func (ctx *RequestCtx) ResetUserValues() { + ctx.userValues.Reset() +} + +// RemoveUserValue removes the given key and the value under it in ctx. +func (ctx *RequestCtx) RemoveUserValue(key string) { + ctx.userValues.Remove(key) +} + +// RemoveUserValueBytes removes the given key and the value under it in ctx. +func (ctx *RequestCtx) RemoveUserValueBytes(key []byte) { + ctx.userValues.RemoveBytes(key) +} + +type connTLSer interface { + Handshake() error + ConnectionState() tls.ConnectionState +} + // IsTLS returns true if the underlying connection is tls.Conn. // // tls.Conn is an encrypted connection (aka SSL, HTTPS). func (ctx *RequestCtx) IsTLS() bool { - _, ok := ctx.c.(*tls.Conn) + // cast to (connTLSer) instead of (*tls.Conn), since it catches + // cases with overridden tls.Conn such as: + // + // type customConn struct { + // *tls.Conn + // + // // other custom fields here + // } + + // perIPConn wraps the net.Conn in the Conn field + if pic, ok := ctx.c.(*perIPConn); ok { + _, ok := pic.Conn.(connTLSer) + return ok + } + + _, ok := ctx.c.(connTLSer) return ok } @@ -496,7 +743,7 @@ // The returned state may be used for verifying TLS version, client certificates, // etc. func (ctx *RequestCtx) TLSConnectionState() *tls.ConnectionState { - tlsConn, ok := ctx.c.(*tls.Conn) + tlsConn, ok := ctx.c.(connTLSer) if !ok { return nil } @@ -504,6 +751,15 @@ return &state } +// Conn returns a reference to the underlying net.Conn. +// +// WARNING: Only use this method if you know what you are doing! +// +// Reading from or writing to the returned connection will end badly! +func (ctx *RequestCtx) Conn() net.Conn { + return ctx.c +} + type firstByteReader struct { c net.Conn ch byte @@ -539,12 +795,9 @@ } func (cl *ctxLogger) Printf(format string, args ...interface{}) { - ctxLoggerLock.Lock() msg := fmt.Sprintf(format, args...) - ctx := cl.ctx - req := &ctx.Request - cl.logger.Printf("%.3f #%016X - %s<->%s - %s %s - %s", - time.Since(ctx.Time()).Seconds(), ctx.ID(), ctx.LocalAddr(), ctx.RemoteAddr(), req.Header.Method(), ctx.URI().FullURI(), msg) + ctxLoggerLock.Lock() + cl.logger.Printf("%.3f %s - %s", time.Since(cl.ctx.ConnTime()).Seconds(), cl.ctx.String(), msg) ctxLoggerLock.Unlock() } @@ -552,6 +805,13 @@ IP: net.IPv4zero, } +// String returns unique string representation of the ctx. +// +// The returned value may be useful for logging. +func (ctx *RequestCtx) String() string { + return fmt.Sprintf("#%016X - %s<->%s - %s %s", ctx.ID(), ctx.LocalAddr(), ctx.RemoteAddr(), ctx.Request.Header.Method(), ctx.URI().FullURI()) +} + // ID returns unique ID of the request. func (ctx *RequestCtx) ID() uint64 { return (ctx.connID << 32) | ctx.connRequestNum @@ -570,7 +830,7 @@ return ctx.time } -// ConnTime returns the time server starts serving the connection +// ConnTime returns the time the server started serving the connection // the current request came from. func (ctx *RequestCtx) ConnTime() time.Time { return ctx.connTime @@ -578,6 +838,8 @@ // ConnRequestNum returns request sequence number // for the current connection. +// +// Sequence starts with 1. func (ctx *RequestCtx) ConnRequestNum() uint64 { return ctx.connRequestNum } @@ -607,40 +869,42 @@ // RequestURI returns RequestURI. 
// -// This uri is valid until returning from RequestHandler. +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) RequestURI() []byte { return ctx.Request.Header.RequestURI() } // URI returns requested uri. // -// The uri is valid until returning from RequestHandler. +// This uri is valid until your request handler returns. func (ctx *RequestCtx) URI() *URI { return ctx.Request.URI() } // Referer returns request referer. // -// The referer is valid until returning from RequestHandler. +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) Referer() []byte { return ctx.Request.Header.Referer() } // UserAgent returns User-Agent header value from the request. +// +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) UserAgent() []byte { return ctx.Request.Header.UserAgent() } // Path returns requested path. // -// The path is valid until returning from RequestHandler. +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) Path() []byte { return ctx.URI().Path() } // Host returns requested host. // -// The host is valid until returning from RequestHandler. +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) Host() []byte { return ctx.URI().Host() } @@ -649,9 +913,9 @@ // // It doesn't return POST'ed arguments - use PostArgs() for this. // -// Returned arguments are valid until returning from RequestHandler. -// // See also PostArgs, FormValue and FormFile. +// +// These args are valid until your request handler returns. func (ctx *RequestCtx) QueryArgs() *Args { return ctx.URI().QueryArgs() } @@ -660,9 +924,9 @@ // // It doesn't return query arguments from RequestURI - use QueryArgs for this. // -// Returned arguments are valid until returning from RequestHandler. -// // See also QueryArgs, FormValue and FormFile. +// +// These args are valid until your request handler returns. func (ctx *RequestCtx) PostArgs() *Args { return ctx.Request.PostArgs() } @@ -678,7 +942,7 @@ // // Use SaveMultipartFile function for permanently saving uploaded file. // -// The returned form is valid until returning from RequestHandler. +// The returned form is valid until your request handler returns. // // See also FormFile and FormValue. func (ctx *RequestCtx) MultipartForm() (*multipart.Form, error) { @@ -692,7 +956,7 @@ // // Use SaveMultipartFile function for permanently saving uploaded file. // -// The returned file header is valid until returning from RequestHandler. +// The returned file header is valid until your request handler returns. func (ctx *RequestCtx) FormFile(key string) (*multipart.FileHeader, error) { mf, err := ctx.MultipartForm() if err != nil { @@ -713,24 +977,53 @@ var ErrMissingFile = errors.New("there is no uploaded file associated with the given key") // SaveMultipartFile saves multipart file fh under the given filename path. -func SaveMultipartFile(fh *multipart.FileHeader, path string) error { - f, err := fh.Open() +func SaveMultipartFile(fh *multipart.FileHeader, path string) (err error) { + var ( + f multipart.File + ff *os.File + ) + f, err = fh.Open() if err != nil { - return err + return } - defer f.Close() - if ff, ok := f.(*os.File); ok { - return os.Rename(ff.Name(), path) + var ok bool + if ff, ok = f.(*os.File); ok { + // Windows can't rename files that are opened. + if err = f.Close(); err != nil { + return + } + + // If renaming fails we try the normal copying method. 
+ // Renaming could fail if the files are on different devices. + if os.Rename(ff.Name(), path) == nil { + return nil + } + + // Reopen f for the code below. + if f, err = fh.Open(); err != nil { + return + } } - ff, err := os.Create(path) - if err != nil { - return err + defer func() { + e := f.Close() + if err == nil { + err = e + } + }() + + if ff, err = os.Create(path); err != nil { + return } - defer ff.Close() + defer func() { + e := ff.Close() + if err == nil { + err = e + } + }() _, err = copyZeroAlloc(ff, f) - return err + return } // FormValue returns form value associated with the given key. @@ -747,7 +1040,7 @@ // * MultipartForm for obtaining values from multipart form. // * FormFile for obtaining uploaded files. // -// The returned value is valid until returning from RequestHandler. +// The returned value is valid until your request handler returns. func (ctx *RequestCtx) FormValue(key string) []byte { v := ctx.QueryArgs().Peek(key) if len(v) > 0 { @@ -787,9 +1080,29 @@ return ctx.Request.Header.IsDelete() } +// IsConnect returns true if request method is CONNECT. +func (ctx *RequestCtx) IsConnect() bool { + return ctx.Request.Header.IsConnect() +} + +// IsOptions returns true if request method is OPTIONS. +func (ctx *RequestCtx) IsOptions() bool { + return ctx.Request.Header.IsOptions() +} + +// IsTrace returns true if request method is TRACE. +func (ctx *RequestCtx) IsTrace() bool { + return ctx.Request.Header.IsTrace() +} + +// IsPatch returns true if request method is PATCH. +func (ctx *RequestCtx) IsPatch() bool { + return ctx.Request.Header.IsPatch() +} + // Method return request method. // -// Returned value is valid until returning from RequestHandler. +// Returned value is valid until your request handler returns. func (ctx *RequestCtx) Method() []byte { return ctx.Request.Header.Method() } @@ -803,6 +1116,12 @@ // // Always returns non-nil result. func (ctx *RequestCtx) RemoteAddr() net.Addr { + if ctx.remoteAddr != nil { + return ctx.remoteAddr + } + if ctx.c == nil { + return zeroTCPAddr + } addr := ctx.c.RemoteAddr() if addr == nil { return zeroTCPAddr @@ -810,10 +1129,21 @@ return addr } +// SetRemoteAddr sets remote address to the given value. +// +// Set nil value to resore default behaviour for using +// connection remote address. +func (ctx *RequestCtx) SetRemoteAddr(remoteAddr net.Addr) { + ctx.remoteAddr = remoteAddr +} + // LocalAddr returns server address for the given request. // // Always returns non-nil result. func (ctx *RequestCtx) LocalAddr() net.Addr { + if ctx.c == nil { + return zeroTCPAddr + } addr := ctx.c.LocalAddr() if addr == nil { return zeroTCPAddr @@ -821,11 +1151,22 @@ return addr } -// RemoteIP returns client ip for the given request. +// RemoteIP returns the client ip the request came from. // // Always returns non-nil result. func (ctx *RequestCtx) RemoteIP() net.IP { - x, ok := ctx.RemoteAddr().(*net.TCPAddr) + return addrToIP(ctx.RemoteAddr()) +} + +// LocalIP returns the server ip the request came to. +// +// Always returns non-nil result. +func (ctx *RequestCtx) LocalIP() net.IP { + return addrToIP(ctx.LocalAddr()) +} + +func addrToIP(addr net.Addr) net.IP { + x, ok := addr.(*net.TCPAddr) if !ok { return net.IPv4zero } @@ -834,6 +1175,8 @@ // Error sets response status code to the given value and sets response body // to the given message. +// +// Warning: this will reset the response headers and body already set! 
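A hedged upload-handling sketch using the FormFile and SaveMultipartFile helpers shown above (the form field name and target directory are illustrative, and the directory is assumed to exist):

```go
package main

import (
	"log"
	"path/filepath"

	"github.com/valyala/fasthttp"
)

// uploadHandler stores the uploaded file from the "file" multipart field.
func uploadHandler(ctx *fasthttp.RequestCtx) {
	fh, err := ctx.FormFile("file")
	if err != nil {
		ctx.Error("missing 'file' form field", fasthttp.StatusBadRequest)
		return
	}

	// SaveMultipartFile renames when possible and falls back to copying.
	// Assumes /tmp/uploads already exists.
	dst := filepath.Join("/tmp/uploads", filepath.Base(fh.Filename))
	if err := fasthttp.SaveMultipartFile(fh, dst); err != nil {
		ctx.Error("cannot save file", fasthttp.StatusInternalServerError)
		return
	}
	ctx.SetBodyString("saved " + dst)
}

func main() {
	log.Fatal(fasthttp.ListenAndServe(":8080", uploadHandler))
}
```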
func (ctx *RequestCtx) Error(msg string, statusCode int) { ctx.Response.Reset() ctx.SetStatusCode(statusCode) @@ -861,11 +1204,18 @@ // * StatusFound (302) // * StatusSeeOther (303) // * StatusTemporaryRedirect (307) +// * StatusPermanentRedirect (308) // // All other statusCode values are replaced by StatusFound (302). // // The redirect uri may be either absolute or relative to the current -// request uri. +// request uri. Fasthttp will always send an absolute uri back to the client. +// To send a relative uri you can use the following code: +// +// strLocation = []byte("Location") // Put this with your top level var () declarations. +// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri") +// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently) +// func (ctx *RequestCtx) Redirect(uri string, statusCode int) { u := AcquireURI() ctx.URI().CopyTo(u) @@ -883,11 +1233,18 @@ // * StatusFound (302) // * StatusSeeOther (303) // * StatusTemporaryRedirect (307) +// * StatusPermanentRedirect (308) // // All other statusCode values are replaced by StatusFound (302). // // The redirect uri may be either absolute or relative to the current -// request uri. +// request uri. Fasthttp will always send an absolute uri back to the client. +// To send a relative uri you can use the following code: +// +// strLocation = []byte("Location") // Put this with your top level var () declarations. +// ctx.Response.Header.SetCanonical(strLocation, "/relative?uri") +// ctx.Response.SetStatusCode(fasthttp.StatusMovedPermanently) +// func (ctx *RequestCtx) RedirectBytes(uri []byte, statusCode int) { s := b2s(uri) ctx.Redirect(s, statusCode) @@ -901,7 +1258,8 @@ func getRedirectStatusCode(statusCode int) int { if statusCode == StatusMovedPermanently || statusCode == StatusFound || - statusCode == StatusSeeOther || statusCode == StatusTemporaryRedirect { + statusCode == StatusSeeOther || statusCode == StatusTemporaryRedirect || + statusCode == StatusPermanentRedirect { return statusCode } return StatusFound @@ -990,7 +1348,7 @@ // PostBody returns POST request body. // -// The returned value is valid until RequestHandler return. +// The returned bytes are valid until your request handler returns. func (ctx *RequestCtx) PostBody() []byte { return ctx.Request.Body() } @@ -1040,7 +1398,7 @@ // It is safe re-using returned logger for logging multiple messages // for the current request. // -// The returned logger is valid until returning from RequestHandler. +// The returned logger is valid until your request handler returns. func (ctx *RequestCtx) Logger() Logger { if ctx.logger.ctx == nil { ctx.logger.ctx = ctx @@ -1098,15 +1456,87 @@ ctx.timeoutResponse = respCopy } +// NextProto adds nph to be processed when key is negotiated when TLS +// connection is established. +// +// This function can only be called before the server is started. 
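A sketch of the NextProto hook documented above: the "example/1" ALPN protocol name and the certificate paths are placeholders.

```go
package main

import (
	"log"
	"net"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("plain HTTP/1.1 response")
		},
	}

	// Register a handler for connections whose TLS handshake negotiates
	// the "example/1" protocol. Must be called before the server starts.
	s.NextProto("example/1", func(c net.Conn) error {
		defer c.Close()
		_, err := c.Write([]byte("hello over example/1\n"))
		return err
	})

	// Placeholder certificate and key paths.
	log.Fatal(s.ListenAndServeTLS(":8443", "cert.pem", "key.pem"))
}
```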
+func (s *Server) NextProto(key string, nph ServeHandler) { + if s.nextProtos == nil { + s.nextProtos = make(map[string]ServeHandler) + } + s.configTLS() + s.tlsConfig.NextProtos = append(s.tlsConfig.NextProtos, key) + s.nextProtos[key] = nph +} + +func (s *Server) getNextProto(c net.Conn) (proto string, err error) { + if tlsConn, ok := c.(connTLSer); ok { + if s.ReadTimeout > 0 { + if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err)) + } + } + + if s.WriteTimeout > 0 { + if err := c.SetWriteDeadline(time.Now().Add(s.WriteTimeout)); err != nil { + panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", s.WriteTimeout, err)) + } + } + + err = tlsConn.Handshake() + if err == nil { + proto = tlsConn.ConnectionState().NegotiatedProtocol + } + } + return +} + +// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted +// connections. It's used by ListenAndServe, ListenAndServeTLS and +// ListenAndServeTLSEmbed so dead TCP connections (e.g. closing laptop mid-download) +// eventually go away. +type tcpKeepaliveListener struct { + *net.TCPListener + keepalive bool + keepalivePeriod time.Duration +} + +func (ln tcpKeepaliveListener) Accept() (net.Conn, error) { + tc, err := ln.AcceptTCP() + if err != nil { + return nil, err + } + if err := tc.SetKeepAlive(ln.keepalive); err != nil { + tc.Close() //nolint:errcheck + return nil, err + } + if ln.keepalivePeriod > 0 { + if err := tc.SetKeepAlivePeriod(ln.keepalivePeriod); err != nil { + tc.Close() //nolint:errcheck + return nil, err + } + } + return tc, nil +} + // ListenAndServe serves HTTP requests from the given TCP4 addr. // // Pass custom listener to Serve if you need listening on non-TCP4 media // such as IPv6. +// +// Accepted connections are configured to enable TCP keep-alives. func (s *Server) ListenAndServe(addr string) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } + if tcpln, ok := ln.(*net.TCPListener); ok { + return s.Serve(tcpKeepaliveListener{ + TCPListener: tcpln, + keepalive: s.TCPKeepalive, + keepalivePeriod: s.TCPKeepalivePeriod, + }) + } return s.Serve(ln) } @@ -1135,11 +1565,23 @@ // // Pass custom listener to Serve if you need listening on non-TCP4 media // such as IPv6. +// +// If the certFile or keyFile has not been provided to the server structure, +// the function will use the previously added TLS configuration. +// +// Accepted connections are configured to enable TCP keep-alives. func (s *Server) ListenAndServeTLS(addr, certFile, keyFile string) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } + if tcpln, ok := ln.(*net.TCPListener); ok { + return s.ServeTLS(tcpKeepaliveListener{ + TCPListener: tcpln, + keepalive: s.TCPKeepalive, + keepalivePeriod: s.TCPKeepalivePeriod, + }, certFile, keyFile) + } return s.ServeTLS(ln, certFile, keyFile) } @@ -1149,59 +1591,129 @@ // // Pass custom listener to Serve if you need listening on arbitrary media // such as IPv6. +// +// If the certFile or keyFile has not been provided the server structure, +// the function will use previously added TLS configuration. +// +// Accepted connections are configured to enable TCP keep-alives. 
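Continuing the keep-alive notes above, a brief sketch of enabling TCP keep-alives through the server options (the 30-second period is illustrative); ListenAndServe then wraps the TCP listener in the tcpKeepaliveListener shown earlier.

```go
package main

import (
	"log"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) { ctx.SetBodyString("ok") },
		// Ask the OS to send TCP keep-alive probes on accepted connections.
		TCPKeepalive:       true,
		TCPKeepalivePeriod: 30 * time.Second,
	}
	log.Fatal(s.ListenAndServe(":8080"))
}
```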
func (s *Server) ListenAndServeTLSEmbed(addr string, certData, keyData []byte) error { ln, err := net.Listen("tcp4", addr) if err != nil { return err } + if tcpln, ok := ln.(*net.TCPListener); ok { + return s.ServeTLSEmbed(tcpKeepaliveListener{ + TCPListener: tcpln, + keepalive: s.TCPKeepalive, + keepalivePeriod: s.TCPKeepalivePeriod, + }, certData, keyData) + } return s.ServeTLSEmbed(ln, certData, keyData) } // ServeTLS serves HTTPS requests from the given listener. // // certFile and keyFile are paths to TLS certificate and key files. +// +// If the certFile or keyFile has not been provided the server structure, +// the function will use previously added TLS configuration. func (s *Server) ServeTLS(ln net.Listener, certFile, keyFile string) error { - lnTLS, err := newTLSListener(ln, certFile, keyFile) - if err != nil { + s.mu.Lock() + err := s.AppendCert(certFile, keyFile) + if err != nil && err != errNoCertOrKeyProvided { + s.mu.Unlock() return err } - return s.Serve(lnTLS) + if s.tlsConfig == nil { + s.mu.Unlock() + return errNoCertOrKeyProvided + } + + // BuildNameToCertificate has been deprecated since 1.14. + // But since we also support older versions we'll keep this here. + s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck + + s.mu.Unlock() + + return s.Serve( + tls.NewListener(ln, s.tlsConfig), + ) } // ServeTLSEmbed serves HTTPS requests from the given listener. // // certData and keyData must contain valid TLS certificate and key data. +// +// If the certFile or keyFile has not been provided the server structure, +// the function will use previously added TLS configuration. func (s *Server) ServeTLSEmbed(ln net.Listener, certData, keyData []byte) error { - lnTLS, err := newTLSListenerEmbed(ln, certData, keyData) - if err != nil { + s.mu.Lock() + + err := s.AppendCertEmbed(certData, keyData) + if err != nil && err != errNoCertOrKeyProvided { + s.mu.Unlock() return err } - return s.Serve(lnTLS) + if s.tlsConfig == nil { + s.mu.Unlock() + return errNoCertOrKeyProvided + } + + // BuildNameToCertificate has been deprecated since 1.14. + // But since we also support older versions we'll keep this here. + s.tlsConfig.BuildNameToCertificate() //nolint:staticcheck + + s.mu.Unlock() + + return s.Serve( + tls.NewListener(ln, s.tlsConfig), + ) } -func newTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, error) { +// AppendCert appends certificate and keyfile to TLS Configuration. +// +// This function allows programmer to handle multiple domains +// in one server structure. See examples/multidomain +func (s *Server) AppendCert(certFile, keyFile string) error { + if len(certFile) == 0 && len(keyFile) == 0 { + return errNoCertOrKeyProvided + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { - return nil, fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) + return fmt.Errorf("cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) } - return newCertListener(ln, &cert), nil + + s.configTLS() + + s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert) + return nil } -func newTLSListenerEmbed(ln net.Listener, certData, keyData []byte) (net.Listener, error) { +// AppendCertEmbed does the same as AppendCert but using in-memory data. 
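A multi-domain sketch of AppendCert with ListenAndServeTLS as described above (certificate file names are placeholders): passing empty certFile/keyFile afterwards makes the server use the certificates appended here.

```go
package main

import (
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("hello from " + string(ctx.Host()))
		},
	}

	// One certificate per served domain; file names are placeholders.
	if err := s.AppendCert("example.com.crt", "example.com.key"); err != nil {
		log.Fatal(err)
	}
	if err := s.AppendCert("example.org.crt", "example.org.key"); err != nil {
		log.Fatal(err)
	}

	// Empty certFile/keyFile: fall back to the certificates appended above.
	log.Fatal(s.ListenAndServeTLS(":8443", "", ""))
}
```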
+func (s *Server) AppendCertEmbed(certData, keyData []byte) error { + if len(certData) == 0 && len(keyData) == 0 { + return errNoCertOrKeyProvided + } + cert, err := tls.X509KeyPair(certData, keyData) if err != nil { - return nil, fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s", + return fmt.Errorf("cannot load TLS key pair from the provided certData(%d) and keyData(%d): %s", len(certData), len(keyData), err) } - return newCertListener(ln, &cert), nil + + s.configTLS() + + s.tlsConfig.Certificates = append(s.tlsConfig.Certificates, cert) + return nil } -func newCertListener(ln net.Listener, cert *tls.Certificate) net.Listener { - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{*cert}, - PreferServerCipherSuites: true, +func (s *Server) configTLS() { + if s.tlsConfig == nil { + s.tlsConfig = &tls.Config{ + PreferServerCipherSuites: true, + } } - return tls.NewListener(ln, tlsConfig) } // DefaultConcurrency is the maximum number of concurrent connections @@ -1218,15 +1730,36 @@ var err error maxWorkersCount := s.getConcurrency() - s.concurrencyCh = make(chan struct{}, maxWorkersCount) + + s.mu.Lock() + { + s.ln = append(s.ln, ln) + if s.done == nil { + s.done = make(chan struct{}) + } + + if s.concurrencyCh == nil { + s.concurrencyCh = make(chan struct{}, maxWorkersCount) + } + } + s.mu.Unlock() + wp := &workerPool{ WorkerFunc: s.serveConn, MaxWorkersCount: maxWorkersCount, LogAllErrors: s.LogAllErrors, Logger: s.logger(), + connState: s.setState, } wp.Start() + // Count our waiting to accept a connection as an open connection. + // This way we can't get into any weird state where just after accepting + // a connection Shutdown is called which reads open as 0 because it isn't + // incremented yet. + atomic.AddInt32(&s.open, 1) + defer atomic.AddInt32(&s.open, -1) + for { if c, err = acceptConn(s, ln, &lastPerIPErrorTime); err != nil { wp.Stop() @@ -1235,10 +1768,14 @@ } return err } + s.setState(c, StateNew) + atomic.AddInt32(&s.open, 1) if !wp.Serve(c) { + atomic.AddInt32(&s.open, -1) s.writeFastError(c, StatusServiceUnavailable, "The connection cannot be served because Server.Concurrency limit exceeded") c.Close() + s.setState(c, StateClosed) if time.Since(lastOverflowErrorTime) > time.Minute { s.logger().Printf("The incoming connection cannot be served, because %d concurrent connections are served. "+ "Try increasing Server.Concurrency", maxWorkersCount) @@ -1251,12 +1788,62 @@ // // There is a hope other servers didn't reach their // concurrency limits yet :) - time.Sleep(100 * time.Millisecond) + // + // See also: https://github.com/valyala/fasthttp/pull/485#discussion_r239994990 + if s.SleepWhenConcurrencyLimitsExceeded > 0 { + time.Sleep(s.SleepWhenConcurrencyLimitsExceeded) + } } c = nil } } +// Shutdown gracefully shuts down the server without interrupting any active connections. +// Shutdown works by first closing all open listeners and then waiting indefinitely for all connections to return to idle and then shut down. +// +// When Shutdown is called, Serve, ListenAndServe, and ListenAndServeTLS immediately return nil. +// Make sure the program doesn't exit and waits instead for Shutdown to return. +// +// Shutdown does not close keepalive connections so its recommended to set ReadTimeout and IdleTimeout to something else than 0. 
+func (s *Server) Shutdown() error { + s.mu.Lock() + defer s.mu.Unlock() + + atomic.StoreInt32(&s.stop, 1) + defer atomic.StoreInt32(&s.stop, 0) + + if s.ln == nil { + return nil + } + + for _, ln := range s.ln { + if err := ln.Close(); err != nil { + return err + } + } + + if s.done != nil { + close(s.done) + } + + // Closing the listener will make Serve() call Stop on the worker pool. + // Setting .stop to 1 will make serveConn() break out of its loop. + // Now we just have to wait until all workers are done. + for { + if open := atomic.LoadInt32(&s.open); open == 0 { + break + } + // This is not an optimal solution but using a sync.WaitGroup + // here causes data races as it's hard to prevent Add() to be called + // while Wait() is waiting. + time.Sleep(time.Millisecond * 100) + } + + s.done = nil + s.ln = nil + return nil +} + func acceptConn(s *Server, ln net.Listener, lastPerIPErrorTime *time.Time) (net.Conn, error) { for { c, err := ln.Accept() @@ -1324,12 +1911,8 @@ ErrPerIPConnLimit = errors.New("too many connections per ip") // ErrConcurrencyLimit may be returned from ServeConn if the number - // of concurrenty served connections exceeds Server.Concurrency. - ErrConcurrencyLimit = errors.New("canot serve the connection because Server.Concurrency concurrent connections are served") - - // ErrKeepaliveTimeout is returned from ServeConn - // if the connection lifetime exceeds MaxKeepaliveDuration. - ErrKeepaliveTimeout = errors.New("exceeded MaxKeepaliveDuration") + // of concurrently served connections exceeds Server.Concurrency. + ErrConcurrencyLimit = errors.New("cannot serve the connection because Server.Concurrency concurrent connections are served") ) // ServeConn serves HTTP requests from the given connection. @@ -1358,23 +1941,50 @@ return ErrConcurrencyLimit } + atomic.AddInt32(&s.open, 1) + err := s.serveConn(c) atomic.AddUint32(&s.concurrency, ^uint32(0)) if err != errHijacked { err1 := c.Close() + s.setState(c, StateClosed) if err == nil { err = err1 } } else { err = nil + s.setState(c, StateHijacked) } return err } var errHijacked = errors.New("connection has been hijacked") +// GetCurrentConcurrency returns a number of currently served +// connections. +// +// This function is intended be used by monitoring systems +func (s *Server) GetCurrentConcurrency() uint32 { + return atomic.LoadUint32(&s.concurrency) +} + +// GetOpenConnectionsCount returns a number of opened connections. +// +// This function is intended be used by monitoring systems +func (s *Server) GetOpenConnectionsCount() int32 { + if atomic.LoadInt32(&s.stop) == 0 { + // Decrement by one to avoid reporting the extra open value that gets + // counted while the server is listening. + return atomic.LoadInt32(&s.open) - 1 + } + // This is not perfect, because s.stop could have changed to zero + // before we load the value of s.open. However, in the common case + // this avoids underreporting open connections by 1 during server shutdown. + return atomic.LoadInt32(&s.open) +} + func (s *Server) getConcurrency() int { n := s.Concurrency if n <= 0 { @@ -1389,141 +1999,322 @@ return atomic.AddUint64(&globalConnID, 1) } -func (s *Server) serveConn(c net.Conn) error { - serverName := s.getServerName() +// DefaultMaxRequestBodySize is the maximum request body size the server +// reads by default. +// +// See Server.MaxRequestBodySize for details. 
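GetCurrentConcurrency and GetOpenConnectionsCount above are documented as hooks for monitoring systems. A small sketch of exposing them from the server itself; the "/status" path and the plain-text output format are illustrative assumptions.

package main

import (
    "fmt"
    "log"

    "github.com/valyala/fasthttp"
)

func main() {
    var s *fasthttp.Server
    s = &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {
            if string(ctx.Path()) == "/status" {
                // RequestCtx implements io.Writer, so Fprintf writes the body.
                fmt.Fprintf(ctx, "concurrency: %d\nopen connections: %d\n",
                    s.GetCurrentConcurrency(), s.GetOpenConnectionsCount())
                return
            }
            ctx.WriteString("hello") //nolint:errcheck
        },
    }
    if err := s.ListenAndServe(":8080"); err != nil {
        log.Fatal(err)
    }
}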
+const DefaultMaxRequestBodySize = 4 * 1024 * 1024 + +func (s *Server) idleTimeout() time.Duration { + if s.IdleTimeout != 0 { + return s.IdleTimeout + } + return s.ReadTimeout +} + +func (s *Server) serveConnCleanup() { + atomic.AddInt32(&s.open, -1) + atomic.AddUint32(&s.concurrency, ^uint32(0)) +} + +func (s *Server) serveConn(c net.Conn) (err error) { + defer s.serveConnCleanup() + atomic.AddUint32(&s.concurrency, 1) + + var proto string + if proto, err = s.getNextProto(c); err != nil { + return + } + if handler, ok := s.nextProtos[proto]; ok { + // Remove read or write deadlines that might have previously been set. + // The next handler is responsible for setting its own deadlines. + if s.ReadTimeout > 0 || s.WriteTimeout > 0 { + if err := c.SetDeadline(zeroTime); err != nil { + panic(fmt.Sprintf("BUG: error in SetDeadline(zeroTime): %s", err)) + } + } + + return handler(c) + } + + var serverName []byte + if !s.NoDefaultServerHeader { + serverName = s.getServerName() + } connRequestNum := uint64(0) connID := nextConnID() - currentTime := time.Now() - connTime := currentTime + connTime := time.Now() + maxRequestBodySize := s.MaxRequestBodySize + if maxRequestBodySize <= 0 { + maxRequestBodySize = DefaultMaxRequestBodySize + } + writeTimeout := s.WriteTimeout + previousWriteTimeout := time.Duration(0) ctx := s.acquireCtx(c) ctx.connTime = connTime + isTLS := ctx.IsTLS() var ( br *bufio.Reader bw *bufio.Writer - err error - timeoutResponse *Response - hijackHandler HijackHandler - - lastReadDeadlineTime time.Time - lastWriteDeadlineTime time.Time + timeoutResponse *Response + hijackHandler HijackHandler + hijackNoResponse bool connectionClose bool isHTTP11 bool + + reqReset bool + continueReadingRequest bool = true ) for { connRequestNum++ - ctx.time = currentTime - if s.ReadTimeout > 0 || s.MaxKeepaliveDuration > 0 { - lastReadDeadlineTime = s.updateReadDeadline(c, ctx, lastReadDeadlineTime) - if lastReadDeadlineTime.IsZero() { - err = ErrKeepaliveTimeout - break + // If this is a keep-alive connection set the idle timeout. + if connRequestNum > 1 { + if d := s.idleTimeout(); d > 0 { + if err := c.SetReadDeadline(time.Now().Add(d)); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", d, err)) + } } } - if !(s.ReduceMemoryUsage || ctx.lastReadDuration > time.Second) || br != nil { + if !s.ReduceMemoryUsage || br != nil { if br == nil { br = acquireReader(ctx) } + + // If this is a keep-alive connection we want to try and read the first bytes + // within the idle time. + if connRequestNum > 1 { + var b []byte + b, err = br.Peek(1) + if len(b) == 0 { + // If reading from a keep-alive connection returns nothing it means + // the connection was closed (either timeout or from the other side). + if err != io.EOF { + err = ErrNothingRead{err} + } + } + } } else { + // If this is a keep-alive connection acquireByteReader will try to peek + // a couple of bytes already so the idle timeout will already be used. 
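The idleTimeout helper above falls back to ReadTimeout when IdleTimeout is zero. A small configuration sketch with arbitrary timeout values showing the two equivalent setups for idle keep-alive connections:

package main

import (
    "time"

    "github.com/valyala/fasthttp"
)

func main() {
    h := func(ctx *fasthttp.RequestCtx) {}

    // Explicit idle timeout for keep-alive connections.
    a := &fasthttp.Server{Handler: h, ReadTimeout: 5 * time.Second, IdleTimeout: 30 * time.Second}

    // IdleTimeout left at zero: ReadTimeout doubles as the idle timeout,
    // so idle keep-alive connections are closed after 5 seconds here.
    b := &fasthttp.Server{Handler: h, ReadTimeout: 5 * time.Second}

    _, _ = a, b
}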
br, err = acquireByteReader(&ctx) } + ctx.Request.isTLS = isTLS + ctx.Response.Header.noDefaultContentType = s.NoDefaultContentType + ctx.Response.Header.noDefaultDate = s.NoDefaultDate + + // Secure header error logs configuration + ctx.Request.Header.secureErrorLogMessage = s.SecureErrorLogMessage + ctx.Response.Header.secureErrorLogMessage = s.SecureErrorLogMessage + ctx.Request.secureErrorLogMessage = s.SecureErrorLogMessage + ctx.Response.secureErrorLogMessage = s.SecureErrorLogMessage + if err == nil { + if s.ReadTimeout > 0 { + if err := c.SetReadDeadline(time.Now().Add(s.ReadTimeout)); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", s.ReadTimeout, err)) + } + } else if s.IdleTimeout > 0 && connRequestNum > 1 { + // If this was an idle connection and the server has an IdleTimeout but + // no ReadTimeout then we should remove the ReadTimeout. + if err := c.SetReadDeadline(zeroTime); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(zeroTime): %s", err)) + } + } if s.DisableHeaderNamesNormalizing { ctx.Request.Header.DisableNormalizing() ctx.Response.Header.DisableNormalizing() } - err = ctx.Request.readLimitBody(br, s.MaxRequestBodySize, s.GetOnly) - if br.Buffered() == 0 || err != nil { + + // Reading Headers. + // + // If we have pipline response in the outgoing buffer, + // we only want to try and read the next headers once. + // If we have to wait for the next request we flush the + // outgoing buffer first so it doesn't have to wait. + if bw != nil && bw.Buffered() > 0 { + err = ctx.Request.Header.readLoop(br, false) + if err == errNeedMore { + err = bw.Flush() + if err != nil { + break + } + + err = ctx.Request.Header.Read(br) + } + } else { + err = ctx.Request.Header.Read(br) + } + + if err == nil { + if onHdrRecv := s.HeaderReceived; onHdrRecv != nil { + reqConf := onHdrRecv(&ctx.Request.Header) + if reqConf.ReadTimeout > 0 { + deadline := time.Now().Add(reqConf.ReadTimeout) + if err := c.SetReadDeadline(deadline); err != nil { + panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", deadline, err)) + } + } + if reqConf.MaxRequestBodySize > 0 { + maxRequestBodySize = reqConf.MaxRequestBodySize + } + if reqConf.WriteTimeout > 0 { + writeTimeout = reqConf.WriteTimeout + } + } + //read body + if s.StreamRequestBody { + err = ctx.Request.readBodyStream(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) + } else { + err = ctx.Request.readLimitBody(br, maxRequestBodySize, s.GetOnly, !s.DisablePreParseMultipartForm) + } + } + + if err == nil { + // If we read any bytes off the wire, we're active. + s.setState(c, StateActive) + } + + if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil { releaseReader(s, br) br = nil } } - currentTime = time.Now() - ctx.lastReadDuration = currentTime.Sub(ctx.time) - if err != nil { if err == io.EOF { err = nil + } else if nr, ok := err.(ErrNothingRead); ok { + if connRequestNum > 1 { + // This is not the first request and we haven't read a single byte + // of a new request yet. This means it's just a keep-alive connection + // closing down either because the remote closed it or because + // or a read timeout on our side. Either way just close the connection + // and don't return any error response. + err = nil + } else { + err = nr.error + } + } + + if err != nil { + bw = s.writeErrorResponse(bw, ctx, serverName, err) } break } // 'Expect: 100-continue' request handling. - // See http://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html for details. 
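The HeaderReceived hook consumed above can return per-request overrides for the read/write timeouts and the body size limit once the headers have been parsed. A hedged sketch; the RequestConfig return type is inferred from the fields used above, and the "/upload" special case is illustrative.

package main

import (
    "time"

    "github.com/valyala/fasthttp"
)

func main() {
    s := &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {},
        HeaderReceived: func(h *fasthttp.RequestHeader) fasthttp.RequestConfig {
            // Loosen the limits only for a (hypothetical) upload endpoint.
            if string(h.RequestURI()) == "/upload" {
                return fasthttp.RequestConfig{
                    ReadTimeout:        time.Minute,
                    MaxRequestBodySize: 64 * 1024 * 1024,
                }
            }
            // Zero values keep the server-wide settings, since the code above
            // only applies overrides that are greater than zero.
            return fasthttp.RequestConfig{}
        },
    }
    _ = s
}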
- if !ctx.Request.Header.noBody() && ctx.Request.MayContinue() { - // Send 'HTTP/1.1 100 Continue' response. - if bw == nil { - bw = acquireWriter(ctx) - } - bw.Write(strResponseContinue) - err = bw.Flush() - releaseWriter(s, bw) - bw = nil - if err != nil { - break - } + // See https://www.w3.org/Protocols/rfc2616/rfc2616-sec8.html#sec8.2.3 for details. + if ctx.Request.MayContinue() { - // Read request body. - if br == nil { - br = acquireReader(ctx) - } - err = ctx.Request.ContinueReadBody(br, s.MaxRequestBodySize) - if br.Buffered() == 0 || err != nil { - releaseReader(s, br) - br = nil + // Allow the ability to deny reading the incoming request body + if s.ContinueHandler != nil { + if continueReadingRequest = s.ContinueHandler(&ctx.Request.Header); !continueReadingRequest { + if br != nil { + br.Reset(ctx.c) + } + + ctx.SetStatusCode(StatusExpectationFailed) + } } - if err != nil { - break + + if continueReadingRequest { + if bw == nil { + bw = acquireWriter(ctx) + } + + // Send 'HTTP/1.1 100 Continue' response. + _, err = bw.Write(strResponseContinue) + if err != nil { + break + } + err = bw.Flush() + if err != nil { + break + } + if s.ReduceMemoryUsage { + releaseWriter(s, bw) + bw = nil + } + + // Read request body. + if br == nil { + br = acquireReader(ctx) + } + + if s.StreamRequestBody { + err = ctx.Request.ContinueReadBodyStream(br, maxRequestBodySize, !s.DisablePreParseMultipartForm) + } else { + err = ctx.Request.ContinueReadBody(br, maxRequestBodySize, !s.DisablePreParseMultipartForm) + } + if (s.ReduceMemoryUsage && br.Buffered() == 0) || err != nil { + releaseReader(s, br) + br = nil + } + if err != nil { + bw = s.writeErrorResponse(bw, ctx, serverName, err) + break + } } } - connectionClose = s.DisableKeepalive || ctx.Request.Header.connectionCloseFast() + connectionClose = s.DisableKeepalive || ctx.Request.Header.ConnectionClose() isHTTP11 = ctx.Request.Header.IsHTTP11() - ctx.Response.Header.SetServerBytes(serverName) + if serverName != nil { + ctx.Response.Header.SetServerBytes(serverName) + } ctx.connID = connID ctx.connRequestNum = connRequestNum - ctx.connTime = connTime - ctx.time = currentTime - s.Handler(ctx) + ctx.time = time.Now() + + // If a client denies a request the handler should not be called + if continueReadingRequest { + s.Handler(ctx) + } timeoutResponse = ctx.timeoutResponse if timeoutResponse != nil { + // Acquire a new ctx because the old one will still be in use by the timeout out handler. ctx = s.acquireCtx(c) timeoutResponse.CopyTo(&ctx.Response) - if br != nil { - // Close connection, since br may be attached to the old ctx via ctx.fbr. - ctx.SetConnectionClose() - } } if !ctx.IsGet() && ctx.IsHead() { ctx.Response.SkipBody = true } + reqReset = true ctx.Request.Reset() hijackHandler = ctx.hijackHandler ctx.hijackHandler = nil - - ctx.userValues.Reset() + hijackNoResponse = ctx.hijackNoResponse && hijackHandler != nil + ctx.hijackNoResponse = false if s.MaxRequestsPerConn > 0 && connRequestNum >= uint64(s.MaxRequestsPerConn) { ctx.SetConnectionClose() } - if s.WriteTimeout > 0 || s.MaxKeepaliveDuration > 0 { - lastWriteDeadlineTime = s.updateWriteDeadline(c, ctx, lastWriteDeadlineTime) + if writeTimeout > 0 { + if err := c.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err)) + } + previousWriteTimeout = writeTimeout + } else if previousWriteTimeout > 0 { + // We don't want a write timeout but we previously set one, remove it. 
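The ContinueHandler branch above lets the server refuse an 'Expect: 100-continue' body before it is read; returning false answers 417 Expectation Failed instead of sending '100 Continue'. A sketch with an assumed 10 MiB cut-off based on the announced Content-Length:

package main

import "github.com/valyala/fasthttp"

func main() {
    s := &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {},
        ContinueHandler: func(h *fasthttp.RequestHeader) bool {
            // Only read bodies the client announced as 10 MiB or smaller.
            return h.ContentLength() <= 10*1024*1024
        },
    }
    _ = s
}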
+ if err := c.SetWriteDeadline(zeroTime); err != nil { + panic(fmt.Sprintf("BUG: error in SetWriteDeadline(zeroTime): %s", err)) + } + previousWriteTimeout = 0 } - // Verify Request.Header.connectionCloseFast() again, - // since request handler might trigger full headers' parsing. - connectionClose = connectionClose || ctx.Request.Header.connectionCloseFast() || ctx.Response.ConnectionClose() + connectionClose = connectionClose || ctx.Response.ConnectionClose() + connectionClose = connectionClose || ctx.Response.ConnectionClose() || (s.CloseOnShutdown && atomic.LoadInt32(&s.stop) == 1) if connectionClose { ctx.Response.Header.SetCanonical(strConnection, strClose) } else if !isHTTP11 { @@ -1533,56 +2324,77 @@ ctx.Response.Header.SetCanonical(strConnection, strKeepAlive) } - if len(ctx.Response.Header.Server()) == 0 { + if serverName != nil && len(ctx.Response.Header.Server()) == 0 { ctx.Response.Header.SetServerBytes(serverName) } - if bw == nil { - bw = acquireWriter(ctx) - } - if err = writeResponse(ctx, bw); err != nil { - break - } - - if br == nil || connectionClose { - err = bw.Flush() - releaseWriter(s, bw) - bw = nil - if err != nil { + if !hijackNoResponse { + if bw == nil { + bw = acquireWriter(ctx) + } + if err = writeResponse(ctx, bw); err != nil { break } + + // Only flush the writer if we don't have another request in the pipeline. + // This is a big of an ugly optimization for https://www.techempower.com/benchmarks/ + // This benchmark will send 16 pipelined requests. It is faster to pack as many responses + // in a TCP packet and send it back at once than waiting for a flush every request. + // In real world circumstances this behaviour could be argued as being wrong. + if br == nil || br.Buffered() == 0 || connectionClose { + err = bw.Flush() + if err != nil { + break + } + } if connectionClose { break } + if s.ReduceMemoryUsage && hijackHandler == nil { + releaseWriter(s, bw) + bw = nil + } } if hijackHandler != nil { - var hjr io.Reader - hjr = c + var hjr io.Reader = c if br != nil { hjr = br br = nil - // br may point to ctx.fbr, so do not return ctx into pool. - ctx = s.acquireCtx(c) + // br may point to ctx.fbr, so do not return ctx into pool below. + ctx = nil } if bw != nil { err = bw.Flush() - releaseWriter(s, bw) - bw = nil if err != nil { break } + releaseWriter(s, bw) + bw = nil + } + err = c.SetDeadline(zeroTime) + if err != nil { + break } - c.SetReadDeadline(zeroTime) - c.SetWriteDeadline(zeroTime) go hijackConnHandler(hjr, c, s, hijackHandler) - hijackHandler = nil err = errHijacked break } - currentTime = time.Now() + if ctx.Request.bodyStream != nil { + if rs, ok := ctx.Request.bodyStream.(*requestStream); ok { + releaseRequestStream(rs) + } + } + + s.setState(c, StateIdle) + ctx.userValues.Reset() + + if atomic.LoadInt32(&s.stop) == 1 { + err = nil + break + } } if br != nil { @@ -1591,79 +2403,35 @@ if bw != nil { releaseWriter(s, bw) } - s.releaseCtx(ctx) - return err -} - -func (s *Server) updateReadDeadline(c net.Conn, ctx *RequestCtx, lastDeadlineTime time.Time) time.Time { - readTimeout := s.ReadTimeout - currentTime := ctx.time - if s.MaxKeepaliveDuration > 0 { - connTimeout := s.MaxKeepaliveDuration - currentTime.Sub(ctx.connTime) - if connTimeout <= 0 { - return zeroTime - } - if connTimeout < readTimeout { - readTimeout = connTimeout + if ctx != nil { + // in unexpected cases the for loop will break + // before request reset call. 
in such cases, call it before + // release to fix #548 + if !reqReset { + ctx.Request.Reset() } + s.releaseCtx(ctx) } - - // Optimization: update read deadline only if more than 25% - // of the last read deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. - if currentTime.Sub(lastDeadlineTime) > (readTimeout >> 2) { - if err := c.SetReadDeadline(currentTime.Add(readTimeout)); err != nil { - panic(fmt.Sprintf("BUG: error in SetReadDeadline(%s): %s", readTimeout, err)) - } - lastDeadlineTime = currentTime - } - return lastDeadlineTime + return } -func (s *Server) updateWriteDeadline(c net.Conn, ctx *RequestCtx, lastDeadlineTime time.Time) time.Time { - writeTimeout := s.WriteTimeout - if s.MaxKeepaliveDuration > 0 { - connTimeout := s.MaxKeepaliveDuration - time.Since(ctx.connTime) - if connTimeout <= 0 { - // MaxKeepAliveDuration exceeded, but let's try sending response anyway - // in 100ms with 'Connection: close' header. - ctx.SetConnectionClose() - connTimeout = 100 * time.Millisecond - } - if connTimeout < writeTimeout { - writeTimeout = connTimeout - } +func (s *Server) setState(nc net.Conn, state ConnState) { + if hook := s.ConnState; hook != nil { + hook(nc, state) } - - // Optimization: update write deadline only if more than 25% - // of the last write deadline exceeded. - // See https://github.com/golang/go/issues/15133 for details. - currentTime := time.Now() - if currentTime.Sub(lastDeadlineTime) > (writeTimeout >> 2) { - if err := c.SetWriteDeadline(currentTime.Add(writeTimeout)); err != nil { - panic(fmt.Sprintf("BUG: error in SetWriteDeadline(%s): %s", writeTimeout, err)) - } - lastDeadlineTime = currentTime - } - return lastDeadlineTime } func hijackConnHandler(r io.Reader, c net.Conn, s *Server, h HijackHandler) { hjc := s.acquireHijackConn(r, c) + h(hjc) - defer func() { - if r := recover(); r != nil { - s.logger().Printf("panic on hijacked conn: %s\nStack trace:\n%s", r, debug.Stack()) - } - - if br, ok := r.(*bufio.Reader); ok { - releaseReader(s, br) - } + if br, ok := r.(*bufio.Reader); ok { + releaseReader(s, br) + } + if !s.KeepHijackedConns { c.Close() s.releaseHijackConn(hjc) - }() - - h(hjc) + } } func (s *Server) acquireHijackConn(r io.Reader, c net.Conn) *hijackConn { @@ -1672,6 +2440,7 @@ hjc := &hijackConn{ Conn: c, r: r, + s: s, } return hjc } @@ -1690,15 +2459,27 @@ type hijackConn struct { net.Conn r io.Reader + s *Server } -func (c hijackConn) Read(p []byte) (int, error) { +func (c *hijackConn) UnsafeConn() net.Conn { + return c.Conn +} + +func (c *hijackConn) Read(p []byte) (int, error) { return c.r.Read(p) } -func (c hijackConn) Close() error { - // hijacked conn is closed in hijackConnHandler. - return nil +func (c *hijackConn) Close() error { + if !c.s.KeepHijackedConns { + // when we do not keep hijacked connections, + // it is closed in hijackConnHandler. 
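The hijack path above hands the connection to a HijackHandler after the normal response has been written. A sketch of registering one from a request handler; the echo loop is illustrative, and whether the connection is closed afterwards follows Server.KeepHijackedConns as handled by hijackConnHandler and hijackConn.Close in this hunk.

package main

import (
    "io"
    "net"

    "github.com/valyala/fasthttp"
)

func main() {
    s := &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {
            ctx.WriteString("switching to echo mode") //nolint:errcheck
            // Runs after the response above has been sent; the connection is
            // closed when the handler returns unless KeepHijackedConns is set.
            ctx.Hijack(func(c net.Conn) {
                io.Copy(c, c) //nolint:errcheck
            })
        },
    }
    _ = s
}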
+ return nil + } + + conn := c.Conn + c.s.releaseHijackConn(c) + return conn.Close() } // LastTimeoutErrorResponse returns the last timeout response set @@ -1727,7 +2508,6 @@ ctx := *ctxP s := ctx.s c := ctx.c - t := ctx.time s.releaseCtx(ctx) // Make GC happy, so it could garbage collect ctx @@ -1735,16 +2515,10 @@ ctx = nil *ctxP = nil - v := s.bytePool.Get() - if v == nil { - v = make([]byte, 1) - } - b := v.([]byte) - n, err := c.Read(b) - ch := b[0] - s.bytePool.Put(v) + var b [1]byte + n, err := c.Read(b[:]) + ctx = s.acquireCtx(c) - ctx.time = t *ctxP = ctx if err != nil { // Treat all errors as EOF on unsuccessful read @@ -1756,7 +2530,7 @@ } ctx.fbr.c = c - ctx.fbr.ch = ch + ctx.fbr.ch = b[0] ctx.fbr.byteRead = false r := acquireReader(ctx) r.Reset(&ctx.fbr) @@ -1799,17 +2573,40 @@ s.writerPool.Put(w) } -func (s *Server) acquireCtx(c net.Conn) *RequestCtx { +func (s *Server) acquireCtx(c net.Conn) (ctx *RequestCtx) { v := s.ctxPool.Get() - var ctx *RequestCtx if v == nil { - v = &RequestCtx{ + ctx = &RequestCtx{ s: s, } + keepBodyBuffer := !s.ReduceMemoryUsage + ctx.Request.keepBodyBuffer = keepBodyBuffer + ctx.Response.keepBodyBuffer = keepBodyBuffer + } else { + ctx = v.(*RequestCtx) } - ctx = v.(*RequestCtx) ctx.c = c - return ctx + return +} + +// Init2 prepares ctx for passing to RequestHandler. +// +// conn is used only for determining local and remote addresses. +// +// This function is intended for custom Server implementations. +// See https://github.com/valyala/httpteleport for details. +func (ctx *RequestCtx) Init2(conn net.Conn, logger Logger, reduceMemoryUsage bool) { + ctx.c = conn + ctx.remoteAddr = nil + ctx.logger.logger = logger + ctx.connID = nextConnID() + ctx.s = fakeServer + ctx.connRequestNum = 0 + ctx.connTime = time.Now() + + keepBodyBuffer := !reduceMemoryUsage + ctx.Request.keepBodyBuffer = keepBodyBuffer + ctx.Response.keepBodyBuffer = keepBodyBuffer } // Init prepares ctx for passing to RequestHandler. @@ -1821,35 +2618,79 @@ if remoteAddr == nil { remoteAddr = zeroTCPAddr } - ctx.c = &fakeAddrer{ - addr: remoteAddr, + c := &fakeAddrer{ + laddr: zeroTCPAddr, + raddr: remoteAddr, } if logger == nil { logger = defaultLogger } - ctx.connID = nextConnID() - ctx.logger.logger = logger - ctx.s = &fakeServer + ctx.Init2(c, logger, true) req.CopyTo(&ctx.Request) - ctx.Response.Reset() - ctx.connRequestNum = 0 - ctx.connTime = time.Now() - ctx.time = ctx.connTime } -var fakeServer Server +// Deadline returns the time when work done on behalf of this context +// should be canceled. Deadline returns ok==false when no deadline is +// set. Successive calls to Deadline return the same results. +// +// This method always returns 0, false and is only present to make +// RequestCtx implement the context interface. +func (ctx *RequestCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +// Done returns a channel that's closed when work done on behalf of this +// context should be canceled. Done may return nil if this context can +// never be canceled. Successive calls to Done return the same value. +func (ctx *RequestCtx) Done() <-chan struct{} { + return ctx.s.done +} + +// Err returns a non-nil error value after Done is closed, +// successive calls to Err return the same error. +// If Done is not yet closed, Err returns nil. +// If Done is closed, Err returns a non-nil error explaining why: +// Canceled if the context was canceled (via server Shutdown) +// or DeadlineExceeded if the context's deadline passed. 
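With the Deadline, Done, Err and Value methods added in this hunk, RequestCtx satisfies context.Context, and Done is closed when the server shuts down. A sketch passing the ctx straight to an API that takes a context; queryDB is a hypothetical stand-in.

package main

import (
    "context"

    "github.com/valyala/fasthttp"
)

// queryDB stands in for any call that accepts a context.Context.
func queryDB(ctx context.Context, query string) error {
    select {
    case <-ctx.Done():
        return ctx.Err() // the server is shutting down
    default:
        return nil
    }
}

func main() {
    s := &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {
            // *RequestCtx can be passed directly as a context.Context.
            if err := queryDB(ctx, "SELECT 1"); err != nil {
                ctx.Error("shutting down", fasthttp.StatusServiceUnavailable)
            }
        },
    }
    _ = s
}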
+func (ctx *RequestCtx) Err() error { + select { + case <-ctx.s.done: + return context.Canceled + default: + return nil + } +} + +// Value returns the value associated with this context for key, or nil +// if no value is associated with key. Successive calls to Value with +// the same key returns the same result. +// +// This method is present to make RequestCtx implement the context interface. +// This method is the same as calling ctx.UserValue(key) +func (ctx *RequestCtx) Value(key interface{}) interface{} { + if keyString, ok := key.(string); ok { + return ctx.UserValue(keyString) + } + return nil +} + +var fakeServer = &Server{ + // Initialize concurrencyCh for TimeoutHandler + concurrencyCh: make(chan struct{}, DefaultConcurrency), +} type fakeAddrer struct { net.Conn - addr net.Addr + laddr net.Addr + raddr net.Addr } func (fa *fakeAddrer) RemoteAddr() net.Addr { - return fa.addr + return fa.raddr } func (fa *fakeAddrer) LocalAddr() net.Addr { - return fa.addr + return fa.laddr } func (fa *fakeAddrer) Read(p []byte) (int, error) { @@ -1869,7 +2710,9 @@ panic("BUG: cannot release timed out RequestCtx") } ctx.c = nil + ctx.remoteAddr = nil ctx.fbr.c = nil + ctx.userValues.Reset() s.ctxPool.Put(ctx) } @@ -1889,13 +2732,107 @@ } func (s *Server) writeFastError(w io.Writer, statusCode int, msg string) { - w.Write(statusLine(statusCode)) + w.Write(statusLine(statusCode)) //nolint:errcheck + + server := "" + if !s.NoDefaultServerHeader { + server = fmt.Sprintf("Server: %s\r\n", s.getServerName()) + } + + date := "" + if !s.NoDefaultDate { + serverDateOnce.Do(updateServerDate) + date = fmt.Sprintf("Date: %s\r\n", serverDate.Load()) + } + fmt.Fprintf(w, "Connection: close\r\n"+ - "Server: %s\r\n"+ - "Date: %s\r\n"+ + server+ + date+ "Content-Type: text/plain\r\n"+ "Content-Length: %d\r\n"+ "\r\n"+ "%s", - s.getServerName(), serverDate.Load(), len(msg), msg) + len(msg), msg) +} + +func defaultErrorHandler(ctx *RequestCtx, err error) { + if _, ok := err.(*ErrSmallBuffer); ok { + ctx.Error("Too big request header", StatusRequestHeaderFieldsTooLarge) + } else if netErr, ok := err.(*net.OpError); ok && netErr.Timeout() { + ctx.Error("Request timeout", StatusRequestTimeout) + } else { + ctx.Error("Error when parsing request", StatusBadRequest) + } +} + +func (s *Server) writeErrorResponse(bw *bufio.Writer, ctx *RequestCtx, serverName []byte, err error) *bufio.Writer { + errorHandler := defaultErrorHandler + if s.ErrorHandler != nil { + errorHandler = s.ErrorHandler + } + + errorHandler(ctx, err) + + if serverName != nil { + ctx.Response.Header.SetServerBytes(serverName) + } + ctx.SetConnectionClose() + if bw == nil { + bw = acquireWriter(ctx) + } + writeResponse(ctx, bw) //nolint:errcheck + bw.Flush() + return bw +} + +// A ConnState represents the state of a client connection to a server. +// It's used by the optional Server.ConnState hook. +type ConnState int + +const ( + // StateNew represents a new connection that is expected to + // send a request immediately. Connections begin at this + // state and then transition to either StateActive or + // StateClosed. + StateNew ConnState = iota + + // StateActive represents a connection that has read 1 or more + // bytes of a request. The Server.ConnState hook for + // StateActive fires before the request has entered a handler + // and doesn't fire again until the request has been + // handled. After the request is handled, the state + // transitions to StateClosed, StateHijacked, or StateIdle. 
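A sketch of the Server.ConnState hook whose states are documented here, keeping a simple gauge of open connections; the counter itself is an assumption for illustration, not part of this diff.

package main

import (
    "net"
    "sync/atomic"

    "github.com/valyala/fasthttp"
)

func main() {
    var open int64

    s := &fasthttp.Server{
        Handler: func(ctx *fasthttp.RequestCtx) {},
        ConnState: func(c net.Conn, state fasthttp.ConnState) {
            switch state {
            case fasthttp.StateNew:
                atomic.AddInt64(&open, 1)
            case fasthttp.StateHijacked, fasthttp.StateClosed:
                // Both are terminal states per the comments above.
                atomic.AddInt64(&open, -1)
            }
        },
    }
    _ = s
}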
+ // For HTTP/2, StateActive fires on the transition from zero + // to one active request, and only transitions away once all + // active requests are complete. That means that ConnState + // cannot be used to do per-request work; ConnState only notes + // the overall state of the connection. + StateActive + + // StateIdle represents a connection that has finished + // handling a request and is in the keep-alive state, waiting + // for a new request. Connections transition from StateIdle + // to either StateActive or StateClosed. + StateIdle + + // StateHijacked represents a hijacked connection. + // This is a terminal state. It does not transition to StateClosed. + StateHijacked + + // StateClosed represents a closed connection. + // This is a terminal state. Hijacked connections do not + // transition to StateClosed. + StateClosed +) + +var stateName = map[ConnState]string{ + StateNew: "new", + StateActive: "active", + StateIdle: "idle", + StateHijacked: "hijacked", + StateClosed: "closed", +} + +func (c ConnState) String() string { + return stateName[c] } diff -Nru golang-github-valyala-fasthttp-20160617/server_test.go golang-github-valyala-fasthttp-1.31.0/server_test.go --- golang-github-valyala-fasthttp-20160617/server_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/server_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,12 +3,15 @@ import ( "bufio" "bytes" + "context" "crypto/tls" "fmt" "io" "io/ioutil" + "mime/multipart" "net" "os" + "reflect" "strings" "sync" "testing" @@ -17,208 +20,229 @@ "github.com/valyala/fasthttp/fasthttputil" ) -func TestRequestCtxRedirect(t *testing.T) { - testRequestCtxRedirect(t, "http://qqq/", "", "http://qqq/") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "", "http://qqq/foo/bar?baz=111") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "#aaa", "http://qqq/foo/bar?baz=111#aaa") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f", "http://qqq/foo/bar?abc=de&f") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f#sf", "http://qqq/foo/bar?abc=de&f#sf") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html", "http://qqq/foo/x.html") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?a=1", "http://qqq/foo/x.html?a=1") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html#aaa=bbb&cc=ddd", "http://qqq/foo/x.html#aaa=bbb&cc=ddd") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz") - testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz") -} +// Make sure RequestCtx implements context.Context +var _ context.Context = &RequestCtx{} -func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) { - var ctx RequestCtx - var req Request - req.SetRequestURI(origURL) - ctx.Init(&req, nil, nil) +func 
TestServerCRNLAfterPost_Pipeline(t *testing.T) { + t.Parallel() - ctx.Redirect(redirectURL, StatusFound) - loc := ctx.Response.Header.Peek("Location") - if string(loc) != expectedURL { - t.Fatalf("unexpected redirect url %q. Expecting %q. origURL=%q, redirectURL=%q", loc, expectedURL, origURL, redirectURL) + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + Logger: &testLogger{}, + } + + ln := fasthttputil.NewInmemoryListener() + defer ln.Close() + + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + }() + + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n" + // <-- this stuff is bogus, but we'll ignore it + "GET / HTTP/1.1\r\nHost: golang.org\r\n\r\n")); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } } -func TestServerResponseServerHeader(t *testing.T) { - serverName := "foobar serv" +func TestServerCRNLAfterPost(t *testing.T) { + t.Parallel() s := &Server{ Handler: func(ctx *RequestCtx) { - name := ctx.Response.Header.Server() - if string(name) != serverName { - fmt.Fprintf(ctx, "unexpected server name: %q. Expecting %q", name, serverName) - } else { - ctx.WriteString("OK") - } - - // make sure the server name is sent to the client after ctx.Response.Reset() - ctx.NotFound() }, - Name: serverName, + Logger: &testLogger{}, + ReadTimeout: time.Millisecond * 100, } ln := fasthttputil.NewInmemoryListener() + defer ln.Close() - serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } - close(serverCh) }() - clientCh := make(chan struct{}) + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + defer c.Close() + if _, err = c.Write([]byte("POST / HTTP/1.1\r\nHost: golang.org\r\nContent-Length: 3\r\n\r\nABC" + + "\r\n\r\n", // <-- this stuff is bogus, but we'll ignore it + )); err != nil { + t.Fatal(err) + } + + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + if err := resp.Read(br); err == nil { + t.Fatal("expected error") // We didn't send a request so we should get an error here. 
+ } +} + +func TestServerPipelineFlush(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + }, + } + ln := fasthttputil.NewInmemoryListener() + go func() { - c, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) - } - br := bufio.NewReader(c) - var resp Response - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) } + }() - if resp.StatusCode() != StatusNotFound { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotFound) - } - if string(resp.Body()) != "404 Page not found" { - t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "404 Page not found") - } - if string(resp.Header.Server()) != serverName { - t.Fatalf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName) - } - if err = c.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatal(err) + } + + // Write a partial request. + if _, err = c.Write([]byte("GET /foo1 HTTP/1.1\r\nHost: ")); err != nil { + t.Fatal(err) + } + go func() { + // Wait for 200ms to finish the request + time.Sleep(time.Millisecond * 200) + + if _, err = c.Write([]byte("google.com\r\n\r\n")); err != nil { + t.Error(err) } - close(clientCh) }() - select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } + start := time.Now() + br := bufio.NewReader(c) + var resp Response - if err := ln.Close(); err != nil { + if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } - select { - case <-serverCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + // Since the second request takes 200ms to finish we expect the first one to be flushed earlier. + d := time.Since(start) + if d > time.Millisecond*100 { + t.Fatalf("had to wait for %v", d) + } + + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } } -func TestServerResponseBodyStream(t *testing.T) { - ln := fasthttputil.NewInmemoryListener() +func TestServerInvalidHeader(t *testing.T) { + t.Parallel() - readyCh := make(chan struct{}) - h := func(ctx *RequestCtx) { - ctx.SetConnectionClose() - if ctx.IsBodyStream() { - t.Fatalf("IsBodyStream must return false") - } - ctx.SetBodyStreamWriter(func(w *bufio.Writer) { - fmt.Fprintf(w, "first") - if err := w.Flush(); err != nil { - return + s := &Server{ + Handler: func(ctx *RequestCtx) { + if ctx.Request.Header.Peek("Foo") != nil || ctx.Request.Header.Peek("Foo ") != nil { + t.Error("expected Foo header") } - <-readyCh - fmt.Fprintf(w, "second") - // there is no need to flush w here, since it will - // be flushed automatically after returning from StreamWriter. 
- }) - if !ctx.IsBodyStream() { - t.Fatalf("IsBodyStream must return true") - } + }, + Logger: &testLogger{}, } - serverCh := make(chan struct{}) + ln := fasthttputil.NewInmemoryListener() + go func() { - if err := Serve(ln, h); err != nil { - t.Fatalf("unexpected error: %s", err) + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) } - close(serverCh) }() - clientCh := make(chan struct{}) - go func() { - c, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) - } - br := bufio.NewReader(c) - var respH ResponseHeader - if err = respH.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if respH.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d", respH.StatusCode(), StatusOK) - } + c, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("POST /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\nContent-Length: 5\r\n\r\n12345")); err != nil { + t.Fatal(err) + } - buf := make([]byte, 1024) - n, err := br.Read(buf) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - b := buf[:n] - if string(b) != "5\r\nfirst\r\n" { - t.Fatalf("unexpected result %q. Expecting %q", b, "5\r\nfirst\r\n") - } - close(readyCh) + br := bufio.NewReader(c) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusBadRequest { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusBadRequest) + } - tail, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" { - t.Fatalf("unexpected tail %q. Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n") - } + c, err = ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET /foo HTTP/1.1\r\nHost: gle.com\r\nFoo : bar\r\n\r\n")); err != nil { + t.Fatal(err) + } - close(clientCh) - }() + br = bufio.NewReader(c) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } - select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + if resp.StatusCode() != StatusBadRequest { + t.Fatalf("unexpected status code: %d. 
Expecting %d", resp.StatusCode(), StatusBadRequest) } - if err := ln.Close(); err != nil { + if err := c.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } - - select { - case <-serverCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) } } -func TestServerDisableKeepalive(t *testing.T) { +func TestServerConnState(t *testing.T) { + t.Parallel() + + states := make([]string, 0) s := &Server{ - Handler: func(ctx *RequestCtx) { - ctx.WriteString("OK") + Handler: func(ctx *RequestCtx) {}, + ConnState: func(conn net.Conn, state ConnState) { + states = append(states, state.String()) }, - DisableKeepalive: true, } ln := fasthttputil.NewInmemoryListener() @@ -226,7 +250,7 @@ serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverCh) }() @@ -235,42 +259,34 @@ go func() { c, err := ln.Dial() if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } br := bufio.NewReader(c) - var resp Response - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) - } - if !resp.ConnectionClose() { - t.Fatalf("expecting 'Connection: close' response header") + // Send 2 requests on the same connection. + for i := 0; i < 2; i++ { + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + var resp Response + if err := resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } } - if string(resp.Body()) != "OK" { - t.Fatalf("unexpected body: %q. Expecting %q", resp.Body(), "OK") + if err := c.Close(); err != nil { + t.Errorf("unexpected error: %s", err) } - - // make sure the connection is closed - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if len(data) > 0 { - t.Fatalf("unexpected data read from the connection: %q. Expecting empty data", data) - } - - close(clientCh) - }() + // Give the server a little bit of time to transition the connection to the close state. + time.Sleep(time.Millisecond * 100) + close(clientCh) + }() select { case <-clientCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } if err := ln.Close(); err != nil { @@ -280,119 +296,309 @@ select { case <-serverCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") + } + + // 2 requests so we go to active and idle twice. + expected := []string{"new", "active", "idle", "active", "idle", "closed"} + + if !reflect.DeepEqual(expected, states) { + t.Fatalf("wrong state, expected %s, got %s", expected, states) } } -func TestServerMaxConnsPerIPLimit(t *testing.T) { +func TestSaveMultipartFile(t *testing.T) { + t.Parallel() + + filea := "This is a test file." 
+ fileb := strings.Repeat("test", 64) + + mr := multipart.NewReader(strings.NewReader(""+ + "--foo\r\n"+ + "Content-Disposition: form-data; name=\"filea\"; filename=\"filea.txt\"\r\n"+ + "Content-Type: text/plain\r\n"+ + "\r\n"+ + filea+"\r\n"+ + "--foo\r\n"+ + "Content-Disposition: form-data; name=\"fileb\"; filename=\"fileb.txt\"\r\n"+ + "Content-Type: text/plain\r\n"+ + "\r\n"+ + fileb+"\r\n"+ + "--foo--\r\n", + ), "foo") + + f, err := mr.ReadForm(64) + if err != nil { + t.Fatal(err) + } + + if err := SaveMultipartFile(f.File["filea"][0], "filea.txt"); err != nil { + t.Fatal(err) + } + defer os.Remove("filea.txt") + + if c, err := ioutil.ReadFile("filea.txt"); err != nil { + t.Fatal(err) + } else if string(c) != filea { + t.Fatalf("filea changed expected %q got %q", filea, c) + } + + // Make sure fileb was saved to a file. + if ff, err := f.File["fileb"][0].Open(); err != nil { + t.Fatal("expected FileHeader.Open to work") + } else if _, ok := ff.(*os.File); !ok { + t.Fatal("expected fileb to be an os.File") + } else { + ff.Close() + } + + if err := SaveMultipartFile(f.File["fileb"][0], "fileb.txt"); err != nil { + t.Fatal(err) + } + defer os.Remove("fileb.txt") + + if c, err := ioutil.ReadFile("fileb.txt"); err != nil { + t.Fatal(err) + } else if string(c) != fileb { + t.Fatalf("fileb changed expected %q got %q", fileb, c) + } +} + +func TestServerName(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.WriteString("OK") }, - MaxConnsPerIP: 1, - Logger: &customLogger{}, } - ln := fasthttputil.NewInmemoryListener() + getReponse := func() []byte { + rw := &readWriter{} + rw.r.WriteString("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n") - serverCh := make(chan struct{}) - go func() { - fakeLN := &fakeIPListener{ - Listener: ln, - } - if err := s.Serve(fakeLN); err != nil { - t.Fatalf("unexpected error: %s", err) + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - close(serverCh) - }() - clientCh := make(chan struct{}) - go func() { - c1, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - c2, err := ln.Dial() + resp, err := ioutil.ReadAll(&rw.w) if err != nil { - t.Fatalf("unexpected error: %s", err) - } - br := bufio.NewReader(c2) - var resp Response - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if resp.StatusCode() != StatusTooManyRequests { - t.Fatalf("unexpected status code for the second connection: %d. Expecting %d", - resp.StatusCode(), StatusTooManyRequests) + t.Fatalf("Unexpected error from ReadAll: %s", err) } - if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error when writing to the first connection: %s", err) - } - br = bufio.NewReader(c1) - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code for the first connection: %d. Expecting %d", - resp.StatusCode(), StatusOK) - } - if string(resp.Body()) != "OK" { - t.Fatalf("unexpected body for the first connection: %q. 
Expecting %q", resp.Body(), "OK") - } - close(clientCh) + return resp + } + + resp := getReponse() + if !bytes.Contains(resp, []byte("\r\nServer: "+string(defaultServerName)+"\r\n")) { + t.Fatalf("Unexpected response %q expected Server: "+string(defaultServerName), resp) + } + + // We can't just overwrite s.Name as fasthttp caches the name in an atomic.Value + s = &Server{ + Handler: func(ctx *RequestCtx) { + }, + Name: "foobar", + } + + resp = getReponse() + if !bytes.Contains(resp, []byte("\r\nServer: foobar\r\n")) { + t.Fatalf("Unexpected response %q expected Server: foobar", resp) + } + + s = &Server{ + Handler: func(ctx *RequestCtx) { + }, + NoDefaultServerHeader: true, + NoDefaultContentType: true, + NoDefaultDate: true, + } + + resp = getReponse() + if bytes.Contains(resp, []byte("\r\nServer: ")) { + t.Fatalf("Unexpected response %q expected no Server header", resp) + } + + if bytes.Contains(resp, []byte("\r\nContent-Type: ")) { + t.Fatalf("Unexpected response %q expected no Content-Type header", resp) + } + + if bytes.Contains(resp, []byte("\r\nDate: ")) { + t.Fatalf("Unexpected response %q expected no Date header", resp) + } +} + +func TestRequestCtxString(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + + s := ctx.String() + expectedS := "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:///" + if s != expectedS { + t.Fatalf("unexpected ctx.String: %q. Expecting %q", s, expectedS) + } + + ctx.Request.SetRequestURI("https://foobar.com/aaa?bb=c") + s = ctx.String() + expectedS = "#0000000000000000 - 0.0.0.0:0<->0.0.0.0:0 - GET https://foobar.com/aaa?bb=c" + if s != expectedS { + t.Fatalf("unexpected ctx.String: %q. Expecting %q", s, expectedS) + } +} + +func TestServerErrSmallBuffer(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.WriteString("shouldn't be never called") //nolint:errcheck + }, + ReadBufferSize: 20, + } + + rw := &readWriter{} + rw.r.WriteString("GET / HTTP/1.1\r\nHost: aabb.com\r\nVERY-long-Header: sdfdfsd dsf dsaf dsf df fsd\r\n\r\n") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) }() + var serverErr error select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + case serverErr = <-ch: + case <-time.After(200 * time.Millisecond): + t.Fatal("timeout") } - if err := ln.Close(); err != nil { + if serverErr == nil { + t.Fatal("expected error") + } + + br := bufio.NewReader(&rw.w) + var resp Response + if err := resp.Read(br); err != nil { t.Fatalf("unexpected error: %s", err) } + statusCode := resp.StatusCode() + if statusCode != StatusRequestHeaderFieldsTooLarge { + t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusRequestHeaderFieldsTooLarge) + } + if !resp.ConnectionClose() { + t.Fatal("missing 'Connection: close' response header") + } - select { - case <-serverCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + expectedErr := errSmallBuffer.Error() + if !strings.Contains(serverErr.Error(), expectedErr) { + t.Fatalf("unexpected log output: %v. 
Expecting %q", serverErr, expectedErr) } } -type fakeIPListener struct { - net.Listener +func TestRequestCtxIsTLS(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + + // tls.Conn + ctx.c = &tls.Conn{} + if !ctx.IsTLS() { + t.Fatal("IsTLS must return true") + } + + // non-tls.Conn + ctx.c = &readWriter{} + if ctx.IsTLS() { + t.Fatal("IsTLS must return false") + } + + // overridden tls.Conn + ctx.c = &struct { + *tls.Conn + fooBar bool + }{} + if !ctx.IsTLS() { + t.Fatal("IsTLS must return true") + } + + ctx.c = &perIPConn{Conn: &tls.Conn{}} + if !ctx.IsTLS() { + t.Fatal("IsTLS must return true") + } } -func (ln *fakeIPListener) Accept() (net.Conn, error) { - conn, err := ln.Listener.Accept() - if err != nil { - return nil, err +func TestRequestCtxRedirectHTTPSSchemeless(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + + s := "GET /foo/bar?baz HTTP/1.1\nHost: aaa.com\n\n" + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := ctx.Request.Read(br); err != nil { + t.Fatalf("cannot read request: %s", err) + } + ctx.Request.isTLS = true + + ctx.Redirect("//foobar.com/aa/bbb", StatusFound) + location := ctx.Response.Header.Peek(HeaderLocation) + expectedLocation := "https://foobar.com/aa/bbb" + if string(location) != expectedLocation { + t.Fatalf("Unexpected location: %q. Expecting %q", location, expectedLocation) } - return &fakeIPConn{ - Conn: conn, - }, nil } -type fakeIPConn struct { - net.Conn +func TestRequestCtxRedirect(t *testing.T) { + t.Parallel() + + testRequestCtxRedirect(t, "http://qqq/", "", "http://qqq/") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "", "http://qqq/foo/bar?baz=111") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "#aaa", "http://qqq/foo/bar?baz=111#aaa") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f", "http://qqq/foo/bar?abc=de&f") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "?abc=de&f#sf", "http://qqq/foo/bar?abc=de&f#sf") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html", "http://qqq/foo/x.html") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?a=1", "http://qqq/foo/x.html?a=1") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html#aaa=bbb&cc=ddd", "http://qqq/foo/x.html#aaa=bbb&cc=ddd") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "x.html?b=1#aaa=bbb&cc=ddd", "http://qqq/foo/x.html?b=1#aaa=bbb&cc=ddd") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html", "http://qqq/x.html") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "/x.html#aaa=bbb&cc=ddd", "http://qqq/x.html#aaa=bbb&cc=ddd") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../x.html", "http://qqq/x.html") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "../../x.html", "http://qqq/x.html") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "./.././../x.html", "http://qqq/x.html") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "http://foo.bar/baz", "http://foo.bar/baz") + testRequestCtxRedirect(t, "http://qqq/foo/bar?baz=111", "https://foo.bar/baz", "https://foo.bar/baz") + testRequestCtxRedirect(t, "https://foo.com/bar?aaa", "//google.com/aaa?bb", "https://google.com/aaa?bb") } -func (conn *fakeIPConn) RemoteAddr() net.Addr { - addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789") - if err != nil { - panic(fmt.Sprintf("BUG: unexpected error: %s", err)) +func testRequestCtxRedirect(t *testing.T, origURL, redirectURL, expectedURL string) { + var ctx RequestCtx + var req Request + req.SetRequestURI(origURL) 
+ ctx.Init(&req, nil, nil) + + ctx.Redirect(redirectURL, StatusFound) + loc := ctx.Response.Header.Peek(HeaderLocation) + if string(loc) != expectedURL { + t.Fatalf("unexpected redirect url %q. Expecting %q. origURL=%q, redirectURL=%q", loc, expectedURL, origURL, redirectURL) } - return addr } -func TestServerConcurrencyLimit(t *testing.T) { +func TestServerResponseServerHeader(t *testing.T) { + t.Parallel() + + serverName := "foobar serv" + s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.WriteString("OK") + name := ctx.Response.Header.Server() + if string(name) != serverName { + fmt.Fprintf(ctx, "unexpected server name: %q. Expecting %q", name, serverName) + } else { + ctx.WriteString("OK") //nolint:errcheck + } + + // make sure the server name is sent to the client after ctx.Response.Reset() + ctx.NotFound() }, - Concurrency: 1, - Logger: &customLogger{}, + Name: serverName, } ln := fasthttputil.NewInmemoryListener() @@ -400,44 +606,37 @@ serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(serverCh) }() clientCh := make(chan struct{}) go func() { - c1, err := ln.Dial() + c, err := ln.Dial() if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } - c2, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) } - br := bufio.NewReader(c2) + br := bufio.NewReader(c) var resp Response if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if resp.StatusCode() != StatusServiceUnavailable { - t.Fatalf("unexpected status code for the second connection: %d. Expecting %d", - resp.StatusCode(), StatusServiceUnavailable) + t.Errorf("unexpected error: %s", err) } - if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error when writing to the first connection: %s", err) + if resp.StatusCode() != StatusNotFound { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotFound) } - br = bufio.NewReader(c1) - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) + if string(resp.Body()) != "404 Page not found" { + t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "404 Page not found") } - if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code for the first connection: %d. Expecting %d", - resp.StatusCode(), StatusOK) + if string(resp.Header.Server()) != serverName { + t.Errorf("unexpected server header: %q. Expecting %q", resp.Header.Server(), serverName) } - if string(resp.Body()) != "OK" { - t.Fatalf("unexpected body for the first connection: %q. 
Expecting %q", resp.Body(), "OK") + if err = c.Close(); err != nil { + t.Errorf("unexpected error: %s", err) } close(clientCh) }() @@ -445,7 +644,7 @@ select { case <-clientCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } if err := ln.Close(); err != nil { @@ -455,1856 +654,3049 @@ select { case <-serverCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } } -func TestServerWriteFastError(t *testing.T) { - s := &Server{ - Name: "foobar", - } - var buf bytes.Buffer - expectedBody := "access denied" - s.writeFastError(&buf, StatusForbidden, expectedBody) - - br := bufio.NewReader(&buf) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - if resp.StatusCode() != StatusForbidden { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusForbidden) - } - body := resp.Body() - if string(body) != expectedBody { - t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) - } - server := string(resp.Header.Server()) - if server != s.Name { - t.Fatalf("unexpected server: %q. Expecting %q", server, s.Name) - } - contentType := string(resp.Header.ContentType()) - if contentType != "text/plain" { - t.Fatalf("unexpected content-type: %q. Expecting %q", contentType, "text/plain") - } - if !resp.Header.ConnectionClose() { - t.Fatalf("expecting 'Connection: close' response header") - } -} +func TestServerResponseBodyStream(t *testing.T) { + t.Parallel() -func TestServerServeTLSEmbed(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - certFile := "./ssl-cert-snakeoil.pem" - keyFile := "./ssl-cert-snakeoil.key" - - certData, err := ioutil.ReadFile(certFile) - if err != nil { - t.Fatalf("unexpected error when reading %q: %s", certFile, err) - } - keyData, err := ioutil.ReadFile(keyFile) - if err != nil { - t.Fatalf("unexpected error when reading %q: %s", keyFile, err) + readyCh := make(chan struct{}) + h := func(ctx *RequestCtx) { + ctx.SetConnectionClose() + if ctx.IsBodyStream() { + t.Fatal("IsBodyStream must return false") + } + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + fmt.Fprintf(w, "first") + if err := w.Flush(); err != nil { + return + } + <-readyCh + fmt.Fprintf(w, "second") + // there is no need to flush w here, since it will + // be flushed automatically after returning from StreamWriter. + }) + if !ctx.IsBodyStream() { + t.Fatal("IsBodyStream must return true") + } } - // start the server - ch := make(chan struct{}) + serverCh := make(chan struct{}) go func() { - err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) { - ctx.WriteString("success") - }) - if err != nil { - t.Fatalf("unexpected error: %s", err) + if err := Serve(ln, h); err != nil { + t.Errorf("unexpected error: %s", err) } - close(ch) + close(serverCh) }() - // establish connection to the server - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - tlsConn := tls.Client(conn, &tls.Config{ - InsecureSkipVerify: true, - }) + clientCh := make(chan struct{}) + go func() { + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(c) + var respH ResponseHeader + if err = respH.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if respH.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. 
Expecting %d", respH.StatusCode(), StatusOK) + } - // send request - if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) - } + buf := make([]byte, 1024) + n, err := br.Read(buf) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + b := buf[:n] + if string(b) != "5\r\nfirst\r\n" { + t.Errorf("unexpected result %q. Expecting %q", b, "5\r\nfirst\r\n") + } + close(readyCh) - // read response - respCh := make(chan struct{}) - go func() { - br := bufio.NewReader(tlsConn) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error") + tail, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("unexpected error: %s", err) } - body := resp.Body() - if string(body) != "success" { - t.Fatalf("unexpected response body %q. Expecting %q", body, "success") + if string(tail) != "6\r\nsecond\r\n0\r\n\r\n" { + t.Errorf("unexpected tail %q. Expecting %q", tail, "6\r\nsecond\r\n0\r\n\r\n") } - close(respCh) + + close(clientCh) }() + select { - case <-respCh: + case <-clientCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } - // close the server - if err = ln.Close(); err != nil { + if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } + select { - case <-ch: + case <-serverCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } } -func TestServerMultipartFormDataRequest(t *testing.T) { - reqS := `POST /upload HTTP/1.1 -Host: qwerty.com -Content-Length: 521 -Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg - -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f1" - -value1 -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="fileaaa"; filename="TODO" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- - -GET / HTTP/1.1 -Host: asbd -Connection: close - -` - - ln := fasthttputil.NewInmemoryListener() +func TestServerDisableKeepalive(t *testing.T) { + t.Parallel() s := &Server{ Handler: func(ctx *RequestCtx) { - switch string(ctx.Path()) { - case "/upload": - f, err := ctx.MultipartForm() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if len(f.Value) != 1 { - t.Fatalf("unexpected values %d. Expecting %d", len(f.Value), 1) - } - if len(f.File) != 1 { - t.Fatalf("unexpected file values %d. Expecting %d", len(f.File), 1) - } - fv := ctx.FormValue("f1") - if string(fv) != "value1" { - t.Fatalf("unexpected form value: %q. 
Expecting %q", fv, "value1") - } - ctx.Redirect("/", StatusSeeOther) - default: - ctx.WriteString("non-upload") - } + ctx.WriteString("OK") //nolint:errcheck }, + DisableKeepalive: true, } - ch := make(chan struct{}) + ln := fasthttputil.NewInmemoryListener() + + serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } - close(ch) + close(serverCh) }() - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if _, err = conn.Write([]byte(reqS)); err != nil { - t.Fatalf("unexpected error: %s", err) - } - - var resp Response - br := bufio.NewReader(conn) - respCh := make(chan struct{}) + clientCh := make(chan struct{}) go func() { - if err := resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) - } - if resp.StatusCode() != StatusSeeOther { - t.Fatalf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther) + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) } - loc := resp.Header.Peek("Location") - if string(loc) != "http://qwerty.com/" { - t.Fatalf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/") + if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) } - - if err := resp.Read(br); err != nil { - t.Fatalf("error when reading the second response: %s", err) + br := bufio.NewReader(c) + var resp Response + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) } if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) } - body := resp.Body() - if string(body) != "non-upload" { - t.Fatalf("unexpected body %q. Expecting %q", body, "non-upload") + if !resp.ConnectionClose() { + t.Error("expecting 'Connection: close' response header") } - close(respCh) + if string(resp.Body()) != "OK" { + t.Errorf("unexpected body: %q. Expecting %q", resp.Body(), "OK") + } + + // make sure the connection is closed + data, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if len(data) > 0 { + t.Errorf("unexpected data read from the connection: %q. Expecting empty data", data) + } + + close(clientCh) }() select { - case <-respCh: + case <-clientCh: case <-time.After(time.Second): - t.Fatalf("timeout") + t.Fatal("timeout") } if err := ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) + t.Fatalf("unexpected error: %s", err) } select { - case <-ch: + case <-serverCh: case <-time.After(time.Second): - t.Fatalf("timeout when waiting for the server to stop") + t.Fatal("timeout") } } -func TestServerDisableHeaderNamesNormalizing(t *testing.T) { - headerName := "CASE-senSITive-HEAder-NAME" - headerNameLower := strings.ToLower(headerName) - headerValue := "foobar baz" +func TestServerMaxConnsPerIPLimit(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { - hv := ctx.Request.Header.Peek(headerName) - if string(hv) != headerValue { - t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) - } - hv = ctx.Request.Header.Peek(headerNameLower) - if len(hv) > 0 { - t.Fatalf("unexpected header value for %q: %q. 
Expecting empty value", headerNameLower, hv) - } - ctx.Response.Header.Set(headerName, headerValue) - ctx.WriteString("ok") - ctx.SetContentType("aaa") + ctx.WriteString("OK") //nolint:errcheck }, - DisableHeaderNamesNormalizing: true, + MaxConnsPerIP: 1, + Logger: &testLogger{}, } - rw := &readWriter{} - rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue)) + ln := fasthttputil.NewInmemoryListener() - ch := make(chan error) + serverCh := make(chan struct{}) go func() { - ch <- s.ServeConn(rw) + fakeLN := &fakeIPListener{ + Listener: ln, + } + if err := s.Serve(fakeLN); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(serverCh) }() - select { - case err := <-ch: + clientCh := make(chan struct{}) + go func() { + c1, err := ln.Dial() if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + t.Errorf("unexpected error: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + c2, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(c2) + var resp Response + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusTooManyRequests { + t.Errorf("unexpected status code for the second connection: %d. Expecting %d", + resp.StatusCode(), StatusTooManyRequests) + } + + if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error when writing to the first connection: %s", err) + } + br = bufio.NewReader(c1) + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code for the first connection: %d. Expecting %d", + resp.StatusCode(), StatusOK) + } + if string(resp.Body()) != "OK" { + t.Errorf("unexpected body for the first connection: %q. Expecting %q", resp.Body(), "OK") + } + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") } - br := bufio.NewReader(&rw.w) - var resp Response - resp.Header.DisableNormalizing() - if err := resp.Read(br); err != nil { + if err := ln.Close(); err != nil { t.Fatalf("unexpected error: %s", err) } - hv := resp.Header.Peek(headerName) - if string(hv) != headerValue { - t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) - } - hv = resp.Header.Peek(headerNameLower) - if len(hv) > 0 { - t.Fatalf("unexpected header value for %q: %q. 
Expecting empty value", headerNameLower, hv) + select { + case <-serverCh: + case <-time.After(time.Second): + t.Fatal("timeout") } } -func TestServerReduceMemoryUsageSerial(t *testing.T) { - ln := fasthttputil.NewInmemoryListener() +type fakeIPListener struct { + net.Listener +} - s := &Server{ - Handler: func(ctx *RequestCtx) {}, - ReduceMemoryUsage: true, - } - - ch := make(chan struct{}) - go func() { - if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) - } - close(ch) - }() - - testServerRequests(t, ln) - - if err := ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) +func (ln *fakeIPListener) Accept() (net.Conn, error) { + conn, err := ln.Listener.Accept() + if err != nil { + return nil, err } + return &fakeIPConn{ + Conn: conn, + }, nil +} - select { - case <-ch: - case <-time.After(time.Second): - t.Fatalf("timeout when waiting for the server to stop") +type fakeIPConn struct { + net.Conn +} + +func (conn *fakeIPConn) RemoteAddr() net.Addr { + addr, err := net.ResolveTCPAddr("tcp4", "1.2.3.4:5789") + if err != nil { + panic(fmt.Sprintf("BUG: unexpected error: %s", err)) } + return addr } -func TestServerReduceMemoryUsageConcurrent(t *testing.T) { - ln := fasthttputil.NewInmemoryListener() +func TestServerConcurrencyLimit(t *testing.T) { + t.Parallel() s := &Server{ - Handler: func(ctx *RequestCtx) {}, - ReduceMemoryUsage: true, + Handler: func(ctx *RequestCtx) { + ctx.WriteString("OK") //nolint:errcheck + }, + Concurrency: 1, + Logger: &testLogger{}, } - ch := make(chan struct{}) + ln := fasthttputil.NewInmemoryListener() + + serverCh := make(chan struct{}) go func() { if err := s.Serve(ln); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } - close(ch) + close(serverCh) }() - gCh := make(chan struct{}) - for i := 0; i < 10; i++ { - go func() { - testServerRequests(t, ln) - gCh <- struct{}{} - }() - } - for i := 0; i < 10; i++ { - select { - case <-gCh: - case <-time.After(time.Second): - t.Fatalf("timeout on goroutine %d", i) + clientCh := make(chan struct{}) + go func() { + c1, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) } + c2, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(c2) + var resp Response + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusServiceUnavailable { + t.Errorf("unexpected status code for the second connection: %d. Expecting %d", + resp.StatusCode(), StatusServiceUnavailable) + } + + if _, err = c1.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil { + t.Errorf("unexpected error when writing to the first connection: %s", err) + } + br = bufio.NewReader(c1) + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code for the first connection: %d. Expecting %d", + resp.StatusCode(), StatusOK) + } + if string(resp.Body()) != "OK" { + t.Errorf("unexpected body for the first connection: %q. 
Expecting %q", resp.Body(), "OK") + } + close(clientCh) + }() + + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") } if err := ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) + t.Fatalf("unexpected error: %s", err) } select { - case <-ch: + case <-serverCh: case <-time.After(time.Second): - t.Fatalf("timeout when waiting for the server to stop") + t.Fatal("timeout") } } -func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) { - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) +func TestServerWriteFastError(t *testing.T) { + t.Parallel() + + s := &Server{ + Name: "foobar", } + var buf bytes.Buffer + expectedBody := "access denied" + s.writeFastError(&buf, StatusForbidden, expectedBody) - br := bufio.NewReader(conn) + br := bufio.NewReader(&buf) var resp Response - for i := 0; i < 10; i++ { - if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil { - t.Fatalf("unexpected error on iteration %d: %s", i, err) - } - - respCh := make(chan struct{}) - go func() { - if err = resp.Read(br); err != nil { - t.Fatalf("unexpected error when reading response on iteration %d: %s", i, err) - } - close(respCh) - }() - select { - case <-respCh: - case <-time.After(time.Second): - t.Fatalf("timeout on iteration %d", i) - } + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) } - - if err = conn.Close(); err != nil { - t.Fatalf("error when closing the connection: %s", err) + if resp.StatusCode() != StatusForbidden { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusForbidden) + } + body := resp.Body() + if string(body) != expectedBody { + t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) + } + server := string(resp.Header.Server()) + if server != s.Name { + t.Fatalf("unexpected server: %q. Expecting %q", server, s.Name) + } + contentType := string(resp.Header.ContentType()) + if contentType != "text/plain" { + t.Fatalf("unexpected content-type: %q. 
Expecting %q", contentType, "text/plain") + } + if !resp.Header.ConnectionClose() { + t.Fatal("expecting 'Connection: close' response header") } } -func TestServerHTTP10ConnectionKeepAlive(t *testing.T) { +func TestServerTLS(t *testing.T) { + t.Parallel() + + text := []byte("Make fasthttp great again") ln := fasthttputil.NewInmemoryListener() - ch := make(chan struct{}) + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Write(text) //nolint:errcheck + }, + } + + certData, keyData, err := GenerateTestCertificate("localhost") + if err != nil { + t.Fatal(err) + } + + err = s.AppendCertEmbed(certData, keyData) + if err != nil { + t.Fatal(err) + } go func() { - err := Serve(ln, func(ctx *RequestCtx) { - if string(ctx.Path()) == "/close" { - ctx.SetConnectionClose() - } - }) + err = s.ServeTLS(ln, "", "") if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Error(err) } - close(ch) }() - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) + c := &Client{ + ReadTimeout: time.Second * 2, + Dial: func(addr string) (net.Conn, error) { + return ln.Dial() + }, + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, } - _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") + + req, res := AcquireRequest(), AcquireResponse() + req.SetRequestURI("https://some.url") + + err = c.Do(req, res) if err != nil { - t.Fatalf("error when writing request: %s", err) + t.Fatal(err) } - _, err = fmt.Fprintf(conn, "%s", "GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") - if err != nil { - t.Fatalf("error when writing request: %s", err) + if !bytes.Equal(text, res.Body()) { + t.Fatal("error transmitting information") } +} - br := bufio.NewReader(conn) - var resp Response - if err = resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) - } - if resp.ConnectionClose() { - t.Fatalf("response mustn't have 'Connection: close' header") - } - if err = resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) +func TestServerTLSReadTimeout(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + ReadTimeout: time.Millisecond * 500, + Logger: &testLogger{}, // Ignore log output. 
+ Handler: func(ctx *RequestCtx) { + }, } - if !resp.ConnectionClose() { - t.Fatalf("response must have 'Connection: close' header") + + certData, keyData, err := GenerateTestCertificate("localhost") + if err != nil { + t.Fatal(err) } - tailCh := make(chan struct{}) + err = s.AppendCertEmbed(certData, keyData) + if err != nil { + t.Fatal(err) + } go func() { - tail, err := ioutil.ReadAll(br) + err = s.ServeTLS(ln, "", "") if err != nil { - t.Fatalf("error when reading tail: %s", err) + t.Error(err) } - if len(tail) > 0 { - t.Fatalf("unexpected non-zero tail %q", tail) - } - close(tailCh) }() - select { - case <-tailCh: - case <-time.After(time.Second): - t.Fatalf("timeout when reading tail") + c, err := ln.Dial() + if err != nil { + t.Error(err) } - if err = conn.Close(); err != nil { - t.Fatalf("error when closing the connection: %s", err) - } + r := make(chan error) - if err = ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) - } + go func() { + b := make([]byte, 1) + _, err := c.Read(b) + c.Close() + r <- err + }() select { - case <-ch: + case err = <-r: case <-time.After(time.Second): - t.Fatalf("timeout when waiting for the server to stop") + } + + if err == nil { + t.Error("server didn't close connection after timeout") } } -func TestServerHTTP10ConnectionClose(t *testing.T) { +func TestServerServeTLSEmbed(t *testing.T) { + t.Parallel() + ln := fasthttputil.NewInmemoryListener() + certData, keyData, err := GenerateTestCertificate("localhost") + if err != nil { + t.Fatal(err) + } + + // start the server ch := make(chan struct{}) go func() { - err := Serve(ln, func(ctx *RequestCtx) { - // The server must close the connection irregardless - // of request and response state set inside request - // handler, since the HTTP/1.0 request - // had no 'Connection: keep-alive' header. - ctx.Request.Header.ResetConnectionClose() - ctx.Request.Header.Set("Connection", "keep-alive") - ctx.Response.Header.ResetConnectionClose() - ctx.Response.Header.Set("Connection", "keep-alive") + err := ServeTLSEmbed(ln, certData, keyData, func(ctx *RequestCtx) { + if !ctx.IsTLS() { + ctx.Error("expecting tls", StatusBadRequest) + return + } + scheme := ctx.URI().Scheme() + if string(scheme) != "https" { + ctx.Error(fmt.Sprintf("unexpected scheme=%q. 
Expecting %q", scheme, "https"), StatusBadRequest) + return + } + ctx.WriteString("success") //nolint:errcheck }) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } close(ch) }() + // establish connection to the server conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } - _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n") - if err != nil { - t.Fatalf("error when writing request: %s", err) - } - - br := bufio.NewReader(conn) - var resp Response - if err = resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) - } + tlsConn := tls.Client(conn, &tls.Config{ + InsecureSkipVerify: true, + }) - if !resp.ConnectionClose() { - t.Fatalf("HTTP1.0 response must have 'Connection: close' header") + // send request + if _, err = tlsConn.Write([]byte("GET / HTTP/1.1\r\nHost: aaa\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) } - tailCh := make(chan struct{}) + // read response + respCh := make(chan struct{}) go func() { - tail, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("error when reading tail: %s", err) + br := bufio.NewReader(tlsConn) + var resp Response + if err := resp.Read(br); err != nil { + t.Error("unexpected error") } - if len(tail) > 0 { - t.Fatalf("unexpected non-zero tail %q", tail) + body := resp.Body() + if string(body) != "success" { + t.Errorf("unexpected response body %q. Expecting %q", body, "success") } - close(tailCh) + close(respCh) }() - select { - case <-tailCh: + case <-respCh: case <-time.After(time.Second): - t.Fatalf("timeout when reading tail") - } - - if err = conn.Close(); err != nil { - t.Fatalf("error when closing the connection: %s", err) + t.Fatal("timeout") } + // close the server if err = ln.Close(); err != nil { - t.Fatalf("error when closing listener: %s", err) + t.Fatalf("unexpected error: %s", err) } - select { case <-ch: case <-time.After(time.Second): - t.Fatalf("timeout when waiting for the server to stop") + t.Fatal("timeout") } } -func TestRequestCtxFormValue(t *testing.T) { - var ctx RequestCtx - var req Request - req.SetRequestURI("/foo/bar?baz=123&aaa=bbb") - req.SetBodyString("qqq=port&mmm=sddd") - req.Header.SetContentType("application/x-www-form-urlencoded") +func TestServerMultipartFormDataRequest(t *testing.T) { + t.Parallel() - ctx.Init(&req, nil, nil) + for _, test := range []struct { + StreamRequestBody bool + DisablePreParseMultipartForm bool + }{ + {false, false}, + {false, true}, + {true, false}, + {true, true}, + } { + reqS := `POST /upload HTTP/1.1 +Host: qwerty.com +Content-Length: 521 +Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg - v := ctx.FormValue("baz") - if string(v) != "123" { - t.Fatalf("unexpected value %q. Expecting %q", v, "123") - } - v = ctx.FormValue("mmm") - if string(v) != "sddd" { - t.Fatalf("unexpected value %q. Expecting %q", v, "sddd") - } - v = ctx.FormValue("aaaasdfsdf") - if len(v) > 0 { - t.Fatalf("unexpected value for unknown key %q", v) +------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="f1" + +value1 +------WebKitFormBoundaryJwfATyF8tmxSJnLg +Content-Disposition: form-data; name="fileaaa"; filename="TODO" +Content-Type: application/octet-stream + +- SessionClient with referer and cookies support. +- Client with requests' pipelining support. +- ProxyHandler similar to FSHandler. +- WebSockets. See https://tools.ietf.org/html/rfc6455 . +- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . 
+ +------WebKitFormBoundaryJwfATyF8tmxSJnLg-- + +GET / HTTP/1.1 +Host: asbd +Connection: close + +` + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + StreamRequestBody: test.StreamRequestBody, + DisablePreParseMultipartForm: test.DisablePreParseMultipartForm, + Handler: func(ctx *RequestCtx) { + switch string(ctx.Path()) { + case "/upload": + f, err := ctx.MultipartForm() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if len(f.Value) != 1 { + t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1) + } + if len(f.File) != 1 { + t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1) + } + fv := ctx.FormValue("f1") + if string(fv) != "value1" { + t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1") + } + ctx.Redirect("/", StatusSeeOther) + default: + ctx.WriteString("non-upload") //nolint:errcheck + } + }, + } + + ch := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() + + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = conn.Write([]byte(reqS)); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var resp Response + br := bufio.NewReader(conn) + respCh := make(chan struct{}) + go func() { + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusSeeOther { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther) + } + loc := resp.Header.Peek(HeaderLocation) + if string(loc) != "http://qwerty.com/" { + t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/") + } + + if err := resp.Read(br); err != nil { + t.Errorf("error when reading the second response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + body := resp.Body() + if string(body) != "non-upload" { + t.Errorf("unexpected body %q. Expecting %q", body, "non-upload") + } + close(respCh) + }() + + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") + } } } -func TestRequestCtxUserValue(t *testing.T) { - var ctx RequestCtx +func TestServerGetWithContent(t *testing.T) { + t.Parallel() - for i := 0; i < 5; i++ { - k := fmt.Sprintf("key-%d", i) - ctx.SetUserValue(k, i) + h := func(ctx *RequestCtx) { + ctx.Success("foo/bar", []byte("success")) } - for i := 5; i < 10; i++ { - k := fmt.Sprintf("key-%d", i) - ctx.SetUserValueBytes([]byte(k), i) + s := &Server{ + Handler: h, } - for i := 0; i < 10; i++ { - k := fmt.Sprintf("key-%d", i) - v := ctx.UserValue(k) - n, ok := v.(int) - if !ok || n != i { - t.Fatalf("unexpected value obtained for key %q: %v. 
Expecting %d", k, v, i) - } + rw := &readWriter{} + rw.r.WriteString("GET / HTTP/1.1\r\nHost: mm.com\r\nContent-Length: 5\r\n\r\nabcde") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + resp := rw.w.String() + if !strings.HasSuffix(resp, "success") { + t.Fatalf("unexpected response %s.", resp) } } -func TestServerHeadRequest(t *testing.T) { +func TestServerDisableHeaderNamesNormalizing(t *testing.T) { + t.Parallel() + + headerName := "CASE-senSITive-HEAder-NAME" + headerNameLower := strings.ToLower(headerName) + headerValue := "foobar baz" s := &Server{ Handler: func(ctx *RequestCtx) { - fmt.Fprintf(ctx, "Request method is %q", ctx.Method()) - ctx.SetContentType("aaa/bbb") + hv := ctx.Request.Header.Peek(headerName) + if string(hv) != headerValue { + t.Errorf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) + } + hv = ctx.Request.Header.Peek(headerNameLower) + if len(hv) > 0 { + t.Errorf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv) + } + ctx.Response.Header.Set(headerName, headerValue) + ctx.WriteString("ok") //nolint:errcheck + ctx.SetContentType("aaa") }, + DisableHeaderNamesNormalizing: true, } rw := &readWriter{} - rw.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString(fmt.Sprintf("GET / HTTP/1.1\r\n%s: %s\r\nHost: google.com\r\n\r\n", headerName, headerValue)) - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } br := bufio.NewReader(&rw.w) var resp Response - resp.SkipBody = true + resp.Header.DisableNormalizing() if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) - } - if resp.Header.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.Header.StatusCode(), StatusOK) - } - if len(resp.Body()) > 0 { - t.Fatalf("Unexpected non-zero body %q", resp.Body()) - } - if resp.Header.ContentLength() != 24 { - t.Fatalf("unexpected content-length %d. Expecting %d", resp.Header.ContentLength(), 24) - } - if string(resp.Header.ContentType()) != "aaa/bbb" { - t.Fatalf("unexpected content-type %q. Expecting %q", resp.Header.ContentType(), "aaa/bbb") + t.Fatalf("unexpected error: %s", err) } - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + hv := resp.Header.Peek(headerName) + if string(hv) != headerValue { + t.Fatalf("unexpected header value for %q: %q. Expecting %q", headerName, hv, headerValue) } - if len(data) > 0 { - t.Fatalf("unexpected remaining data %q", data) + hv = resp.Header.Peek(headerNameLower) + if len(hv) > 0 { + t.Fatalf("unexpected header value for %q: %q. Expecting empty value", headerNameLower, hv) } } -func TestServerExpect100Continue(t *testing.T) { +func TestServerReduceMemoryUsageSerial(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + s := &Server{ - Handler: func(ctx *RequestCtx) { - if !ctx.IsPost() { - t.Fatalf("unexpected method %q. Expecting POST", ctx.Method()) - } - if string(ctx.Path()) != "/foo" { - t.Fatalf("unexpected path %q. 
Expecting %q", ctx.Path(), "/foo") - } - ct := ctx.Request.Header.ContentType() - if string(ct) != "a/b" { - t.Fatalf("unexpectected content-type: %q. Expecting %q", ct, "a/b") - } - if string(ctx.PostBody()) != "12345" { - t.Fatalf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") - } - ctx.WriteString("foobar") - }, + Handler: func(ctx *RequestCtx) {}, + ReduceMemoryUsage: true, } - rw := &readWriter{} - rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") - - ch := make(chan error) + ch := make(chan struct{}) go func() { - ch <- s.ServeConn(rw) - }() - - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + close(ch) + }() - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar") + testServerRequests(t, ln) - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) } - if len(data) > 0 { - t.Fatalf("unexpected remaining data %q", data) + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") } } -func TestCompressHandler(t *testing.T) { - expectedBody := "foo/bar/baz" - h := CompressHandler(func(ctx *RequestCtx) { - ctx.Write([]byte(expectedBody)) - }) +func TestServerReduceMemoryUsageConcurrent(t *testing.T) { + t.Parallel() - var ctx RequestCtx - var resp Response + ln := fasthttputil.NewInmemoryListener() - // verify uncompressed response - h(&ctx) - s := ctx.Response.String() - br := bufio.NewReader(bytes.NewBufferString(s)) - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - ce := resp.Header.Peek("Content-Encoding") - if string(ce) != "" { - t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") - } - body := resp.Body() - if string(body) != expectedBody { - t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + ReduceMemoryUsage: true, } - // verify gzip-compressed response - ctx.Request.Reset() - ctx.Response.Reset() - ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + ch := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() - h(&ctx) - s = ctx.Response.String() - br = bufio.NewReader(bytes.NewBufferString(s)) - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - ce = resp.Header.Peek("Content-Encoding") - if string(ce) != "gzip" { - t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") - } - body, err := resp.BodyGunzip() - if err != nil { - t.Fatalf("unexpected error: %s", err) + gCh := make(chan struct{}) + for i := 0; i < 10; i++ { + go func() { + testServerRequests(t, ln) + gCh <- struct{}{} + }() } - if string(body) != expectedBody { - t.Fatalf("unexpected body %q. 
Expecting %q", body, expectedBody) + for i := 0; i < 10; i++ { + select { + case <-gCh: + case <-time.After(time.Second): + t.Fatalf("timeout on goroutine %d", i) + } } - // an attempt to compress already compressed response - ctx.Request.Reset() - ctx.Response.Reset() - ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") - hh := CompressHandler(h) - hh(&ctx) - s = ctx.Response.String() - br = bufio.NewReader(bytes.NewBufferString(s)) - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) } - ce = resp.Header.Peek("Content-Encoding") - if string(ce) != "gzip" { - t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") } - body, err = resp.BodyGunzip() +} + +func testServerRequests(t *testing.T, ln *fasthttputil.InmemoryListener) { + conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } - if string(body) != expectedBody { - t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) - } - // verify deflate-compressed response - ctx.Request.Reset() - ctx.Response.Reset() - ctx.Request.Header.Set("Accept-Encoding", "foobar, deflate, sdhc") + br := bufio.NewReader(conn) + var resp Response + for i := 0; i < 10; i++ { + if _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nHost: aaa\r\n\r\n"); err != nil { + t.Fatalf("unexpected error on iteration %d: %s", i, err) + } - h(&ctx) - s = ctx.Response.String() - br = bufio.NewReader(bytes.NewBufferString(s)) - if err := resp.Read(br); err != nil { - t.Fatalf("unexpected error: %s", err) - } - ce = resp.Header.Peek("Content-Encoding") - if string(ce) != "deflate" { - t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate") - } - body, err = resp.BodyInflate() - if err != nil { - t.Fatalf("unexpected error: %s", err) + respCh := make(chan struct{}) + go func() { + if err = resp.Read(br); err != nil { + t.Errorf("unexpected error when reading response on iteration %d: %s", i, err) + } + close(respCh) + }() + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatalf("timeout on iteration %d", i) + } } - if string(body) != expectedBody { - t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + + if err = conn.Close(); err != nil { + t.Fatalf("error when closing the connection: %s", err) } } -func TestRequestCtxWriteString(t *testing.T) { - var ctx RequestCtx - n, err := ctx.WriteString("foo") +func TestServerHTTP10ConnectionKeepAlive(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + ch := make(chan struct{}) + go func() { + err := Serve(ln, func(ctx *RequestCtx) { + if string(ctx.Path()) == "/close" { + ctx.SetConnectionClose() + } + }) + if err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() + + conn, err := ln.Dial() if err != nil { t.Fatalf("unexpected error: %s", err) } - if n != 3 { - t.Fatalf("unexpected n %d. 
Expecting 3", n) + _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") + if err != nil { + t.Fatalf("error when writing request: %s", err) } - n, err = ctx.WriteString("привет") + _, err = fmt.Fprintf(conn, "%s", "GET /close HTTP/1.0\r\nHost: aaa\r\nConnection: keep-alive\r\n\r\n") if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Fatalf("error when writing request: %s", err) } - if n != 12 { - t.Fatalf("unexpected n=%d. Expecting 12", n) + + br := bufio.NewReader(conn) + var resp Response + if err = resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) + } + if resp.ConnectionClose() { + t.Fatal("response mustn't have 'Connection: close' header") + } + if err = resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) + } + if !resp.ConnectionClose() { + t.Fatal("response must have 'Connection: close' header") } - s := ctx.Response.Body() - if string(s) != "fooпривет" { - t.Fatalf("unexpected response body %q. Expecting %q", s, "fooпривет") + tailCh := make(chan struct{}) + go func() { + tail, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("error when reading tail: %s", err) + } + if len(tail) > 0 { + t.Errorf("unexpected non-zero tail %q", tail) + } + close(tailCh) + }() + + select { + case <-tailCh: + case <-time.After(time.Second): + t.Fatal("timeout when reading tail") + } + + if err = conn.Close(); err != nil { + t.Fatalf("error when closing the connection: %s", err) + } + + if err = ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") } } -func TestServeConnNonHTTP11KeepAlive(t *testing.T) { - rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n") +func TestServerHTTP10ConnectionClose(t *testing.T) { + t.Parallel() - requestsServed := 0 + ln := fasthttputil.NewInmemoryListener() ch := make(chan struct{}) go func() { - err := ServeConn(rw, func(ctx *RequestCtx) { - requestsServed++ - ctx.SuccessString("aaa/bbb", "foobar") + err := Serve(ln, func(ctx *RequestCtx) { + // The server must close the connection irregardless + // of request and response state set inside request + // handler, since the HTTP/1.0 request + // had no 'Connection: keep-alive' header. + ctx.Request.Header.ResetConnectionClose() + ctx.Request.Header.Set(HeaderConnection, "keep-alive") + ctx.Response.Header.ResetConnectionClose() + ctx.Response.Header.Set(HeaderConnection, "keep-alive") }) if err != nil { - t.Fatalf("unexpected error in ServeConn: %s", err) + t.Errorf("unexpected error: %s", err) } close(ch) }() - select { - case <-ch: - case <-time.After(time.Second): - t.Fatalf("timeout") + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + _, err = fmt.Fprintf(conn, "%s", "GET / HTTP/1.0\r\nHost: aaa\r\n\r\n") + if err != nil { + t.Fatalf("error when writing request: %s", err) } - br := bufio.NewReader(&rw.w) - + br := bufio.NewReader(conn) var resp Response - - // verify the first response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) - } - if string(resp.Header.Peek("Connection")) != "keep-alive" { - t.Fatalf("unexpected Connection header %q. 
Expecting %q", resp.Header.Peek("Connection"), "keep-alive") - } - if resp.Header.ConnectionClose() { - t.Fatalf("unexpected Connection: close") + if err = resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) } - // verify the second response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) - } - if string(resp.Header.Peek("Connection")) != "close" { - t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek("Connection"), "close") + if !resp.ConnectionClose() { + t.Fatal("HTTP1.0 response must have 'Connection: close' header") } - if !resp.Header.ConnectionClose() { - t.Fatalf("expecting Connection: close") + + tailCh := make(chan struct{}) + go func() { + tail, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("error when reading tail: %s", err) + } + if len(tail) > 0 { + t.Errorf("unexpected non-zero tail %q", tail) + } + close(tailCh) + }() + + select { + case <-tailCh: + case <-time.After(time.Second): + t.Fatal("timeout when reading tail") } - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + if err = conn.Close(); err != nil { + t.Fatalf("error when closing the connection: %s", err) } - if len(data) != 0 { - t.Fatalf("Unexpected data read after responses %q", data) + + if err = ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) } - if requestsServed != 2 { - t.Fatalf("unexpected number of requests served: %d. Expecting 2", requestsServed) + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") } } -func TestRequestCtxSetBodyStreamWriter(t *testing.T) { +func TestRequestCtxFormValue(t *testing.T) { + t.Parallel() + var ctx RequestCtx var req Request - ctx.Init(&req, nil, defaultLogger) - - if ctx.IsBodyStream() { - t.Fatalf("IsBodyStream must return false") - } - ctx.SetBodyStreamWriter(func(w *bufio.Writer) { - fmt.Fprintf(w, "body writer line 1\n") - if err := w.Flush(); err != nil { - t.Fatalf("unexpected error: %s", err) - } - fmt.Fprintf(w, "body writer line 2\n") - }) - if !ctx.IsBodyStream() { - t.Fatalf("IsBodyStream must return true") - } + req.SetRequestURI("/foo/bar?baz=123&aaa=bbb") + req.SetBodyString("qqq=port&mmm=sddd") + req.Header.SetContentType("application/x-www-form-urlencoded") - s := ctx.Response.String() + ctx.Init(&req, nil, nil) - br := bufio.NewReader(bytes.NewBufferString(s)) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("Error when reading response: %s", err) + v := ctx.FormValue("baz") + if string(v) != "123" { + t.Fatalf("unexpected value %q. Expecting %q", v, "123") } - - body := string(resp.Body()) - expectedBody := "body writer line 1\nbody writer line 2\n" - if body != expectedBody { - t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) + v = ctx.FormValue("mmm") + if string(v) != "sddd" { + t.Fatalf("unexpected value %q. 
Expecting %q", v, "sddd") + } + v = ctx.FormValue("aaaasdfsdf") + if len(v) > 0 { + t.Fatalf("unexpected value for unknown key %q", v) } } -func TestRequestCtxIfModifiedSince(t *testing.T) { - var ctx RequestCtx - var req Request - ctx.Init(&req, nil, defaultLogger) +func TestRequestCtxUserValue(t *testing.T) { + t.Parallel() - lastModified := time.Now().Add(-time.Hour) + var ctx RequestCtx - if !ctx.IfModifiedSince(lastModified) { - t.Fatalf("IfModifiedSince must return true for non-existing If-Modified-Since header") + for i := 0; i < 5; i++ { + k := fmt.Sprintf("key-%d", i) + ctx.SetUserValue(k, i) } - - ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) - - if ctx.IfModifiedSince(lastModified) { - t.Fatalf("If-Modified-Since current time must return false") + for i := 5; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + ctx.SetUserValueBytes([]byte(k), i) } - past := lastModified.Add(-time.Hour) - if ctx.IfModifiedSince(past) { - t.Fatalf("If-Modified-Since past time must return false") + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + v := ctx.UserValue(k) + n, ok := v.(int) + if !ok || n != i { + t.Fatalf("unexpected value obtained for key %q: %v. Expecting %d", k, v, i) + } + } + vlen := 0 + ctx.VisitUserValues(func(key []byte, value interface{}) { + vlen++ + v := ctx.UserValueBytes(key) + if v != value { + t.Fatalf("unexpected value obtained from VisitUserValues for key: %q, expecting: %#v but got: %#v", key, v, value) + } + }) + if len(ctx.userValues) != vlen { + t.Fatalf("the length of user values returned from VisitUserValues is not equal to the length of the userValues, expecting: %d but got: %d", len(ctx.userValues), vlen) } - future := lastModified.Add(time.Hour) - if !ctx.IfModifiedSince(future) { - t.Fatalf("If-Modified-Since future time must return true") + ctx.ResetUserValues() + for i := 0; i < 10; i++ { + k := fmt.Sprintf("key-%d", i) + v := ctx.UserValue(k) + if v != nil { + t.Fatalf("unexpected value obtained for key %q: %v. Expecting nil", k, v) + } } } -func TestRequestCtxSendFileNotModified(t *testing.T) { - var ctx RequestCtx - var req Request - ctx.Init(&req, nil, defaultLogger) +func TestServerHeadRequest(t *testing.T) { + t.Parallel() - filePath := "./server_test.go" - lastModified, err := FileLastModified(filePath) - if err != nil { - t.Fatalf("unexpected error: %s", err) + s := &Server{ + Handler: func(ctx *RequestCtx) { + fmt.Fprintf(ctx, "Request method is %q", ctx.Method()) + ctx.SetContentType("aaa/bbb") + }, } - ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) - ctx.SendFile(filePath) + rw := &readWriter{} + rw.r.WriteString("HEAD /foobar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") - s := ctx.Response.String() + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + br := bufio.NewReader(&rw.w) var resp Response - br := bufio.NewReader(bytes.NewBufferString(s)) + resp.SkipBody = true if err := resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) + t.Fatalf("Unexpected error when parsing response: %s", err) } - if resp.StatusCode() != StatusNotModified { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusNotModified) + if resp.Header.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. 
Expecting %d", resp.Header.StatusCode(), StatusOK) } if len(resp.Body()) > 0 { - t.Fatalf("unexpected non-zero response body: %q", resp.Body()) - } -} - -func TestRequestCtxSendFileModified(t *testing.T) { - var ctx RequestCtx - var req Request - ctx.Init(&req, nil, defaultLogger) - - filePath := "./server_test.go" - lastModified, err := FileLastModified(filePath) - if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Fatalf("Unexpected non-zero body %q", resp.Body()) } - lastModified = lastModified.Add(-time.Hour) - ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) - - ctx.SendFile(filePath) - - s := ctx.Response.String() - - var resp Response - br := bufio.NewReader(bytes.NewBufferString(s)) - if err := resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) + if resp.Header.ContentLength() != 24 { + t.Fatalf("unexpected content-length %d. Expecting %d", resp.Header.ContentLength(), 24) } - if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + if string(resp.Header.ContentType()) != "aaa/bbb" { + t.Fatalf("unexpected content-type %q. Expecting %q", resp.Header.ContentType(), "aaa/bbb") } - f, err := os.Open(filePath) - if err != nil { - t.Fatalf("cannot open file: %s", err) - } - body, err := ioutil.ReadAll(f) - f.Close() + data, err := ioutil.ReadAll(br) if err != nil { - t.Fatalf("error when reading file: %s", err) + t.Fatalf("Unexpected error when reading remaining data: %s", err) } - - if !bytes.Equal(resp.Body(), body) { - t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) + if len(data) > 0 { + t.Fatalf("unexpected remaining data %q", data) } } -func TestRequestCtxSendFile(t *testing.T) { - var ctx RequestCtx - var req Request - ctx.Init(&req, nil, defaultLogger) - - filePath := "./server_test.go" - ctx.SendFile(filePath) +func TestServerExpect100Continue(t *testing.T) { + t.Parallel() - w := &bytes.Buffer{} - bw := bufio.NewWriter(w) - if err := ctx.Response.Write(bw); err != nil { - t.Fatalf("error when writing response: %s", err) - } - if err := bw.Flush(); err != nil { - t.Fatalf("error when flushing response: %s", err) + s := &Server{ + Handler: func(ctx *RequestCtx) { + if !ctx.IsPost() { + t.Errorf("unexpected method %q. Expecting POST", ctx.Method()) + } + if string(ctx.Path()) != "/foo" { + t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo") + } + ct := ctx.Request.Header.ContentType() + if string(ct) != "a/b" { + t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + if string(ctx.PostBody()) != "12345" { + t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") + } + ctx.WriteString("foobar") //nolint:errcheck + }, } - var resp Response - br := bufio.NewReader(w) - if err := resp.Read(br); err != nil { - t.Fatalf("error when reading response: %s", err) - } - if resp.StatusCode() != StatusOK { - t.Fatalf("unexpected status code: %d. 
Expecting %d", resp.StatusCode(), StatusOK) - } + rw := &readWriter{} + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") - f, err := os.Open(filePath) - if err != nil { - t.Fatalf("cannot open file: %s", err) + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - body, err := ioutil.ReadAll(f) - f.Close() + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusOK, string(defaultContentType), "foobar") + + data, err := ioutil.ReadAll(br) if err != nil { - t.Fatalf("error when reading file: %s", err) + t.Fatalf("Unexpected error when reading remaining data: %s", err) } - - if !bytes.Equal(resp.Body(), body) { - t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) + if len(data) > 0 { + t.Fatalf("unexpected remaining data %q", data) } } -func TestRequestCtxHijack(t *testing.T) { - hijackStartCh := make(chan struct{}) - hijackStopCh := make(chan struct{}) +func TestServerContinueHandler(t *testing.T) { + t.Parallel() + + acceptContentLength := 5 s := &Server{ - Handler: func(ctx *RequestCtx) { - ctx.Hijack(func(c net.Conn) { - <-hijackStartCh + ContinueHandler: func(headers *RequestHeader) bool { + if !headers.IsPost() { + t.Errorf("unexpected method %q. Expecting POST", headers.Method()) + } - b := make([]byte, 1) - // ping-pong echo via hijacked conn - for { - n, err := c.Read(b) - if n != 1 { - if err == io.EOF { - close(hijackStopCh) - return - } - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - t.Fatalf("unexpected number of bytes read: %d. Expecting 1", n) - } - if _, err = c.Write(b); err != nil { - t.Fatalf("unexpected error when writing data: %s", err) - } - } - }) - ctx.Success("foo/bar", []byte("hijack it!")) + ct := headers.ContentType() + if string(ct) != "a/b" { + t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + + // Pass on any request that isn't the accepted content length + return headers.contentLength == acceptContentLength + }, + Handler: func(ctx *RequestCtx) { + if ctx.Request.Header.contentLength != acceptContentLength { + t.Errorf("all requests with content-length: other than %d, should be denied", acceptContentLength) + } + if !ctx.IsPost() { + t.Errorf("unexpected method %q. Expecting POST", ctx.Method()) + } + if string(ctx.Path()) != "/foo" { + t.Errorf("unexpected path %q. Expecting %q", ctx.Path(), "/foo") + } + ct := ctx.Request.Header.ContentType() + if string(ct) != "a/b" { + t.Errorf("unexpectected content-type: %q. Expecting %q", ct, "a/b") + } + if string(ctx.PostBody()) != "12345" { + t.Errorf("unexpected body: %q. Expecting %q", ctx.PostBody(), "12345") + } + ctx.WriteString("foobar") //nolint:errcheck }, } - hijackedString := "foobar baz hijacked!!!" 
- rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString(hijackedString) + sendRequest := func(rw *readWriter, expectedStatusCode int, expectedResponse string) { + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, expectedStatusCode, string(defaultContentType), expectedResponse) - select { - case err := <-ch: + data, err := ioutil.ReadAll(br) if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) > 0 { + t.Fatalf("unexpected remaining data %q", data) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") } - br := bufio.NewReader(&rw.w) + // The same server should not fail when handling the three different types of requests + // Regular requests + // Expect 100 continue accepted + // Exepect 100 continue denied + rw := &readWriter{} + for i := 0; i < 25; i++ { + + // Regular requests without Expect 100 continue header + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + sendRequest(rw, StatusOK, "foobar") + + // Regular Expect 100 continue reqeuests that are accepted + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 5\r\nContent-Type: a/b\r\n\r\n12345") + sendRequest(rw, StatusOK, "foobar") + + // Requests being denied + rw.r.Reset() + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: gle.com\r\nExpect: 100-continue\r\nContent-Length: 6\r\nContent-Type: a/b\r\n\r\n123456") + sendRequest(rw, StatusExpectationFailed, "") + } +} + +func TestCompressHandler(t *testing.T) { + t.Parallel() + + expectedBody := string(createFixedBody(2e4)) + h := CompressHandler(func(ctx *RequestCtx) { + ctx.Write([]byte(expectedBody)) //nolint:errcheck + }) + + var ctx RequestCtx + var resp Response + + // verify uncompressed response + h(&ctx) + s := ctx.Response.String() + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce := resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "") + } + body := resp.Body() + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify gzip-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + body, err := resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. 
Expecting %q", body, expectedBody) + } + + // an attempt to compress already compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set("Accept-Encoding", "gzip, deflate, sdhc") + hh := CompressHandler(h) + hh(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "gzip" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "gzip") + } + body, err = resp.BodyGunzip() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } + + // verify deflate-compressed response + ctx.Request.Reset() + ctx.Response.Reset() + ctx.Request.Header.Set(HeaderAcceptEncoding, "foobar, deflate, sdhc") + + h(&ctx) + s = ctx.Response.String() + br = bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + ce = resp.Header.Peek(HeaderContentEncoding) + if string(ce) != "deflate" { + t.Fatalf("unexpected Content-Encoding: %q. Expecting %q", ce, "deflate") + } + body, err = resp.BodyInflate() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if string(body) != expectedBody { + t.Fatalf("unexpected body %q. Expecting %q", body, expectedBody) + } +} + +func TestRequestCtxWriteString(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + n, err := ctx.WriteString("foo") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if n != 3 { + t.Fatalf("unexpected n %d. Expecting 3", n) + } + n, err = ctx.WriteString("привет") + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if n != 12 { + t.Fatalf("unexpected n=%d. Expecting 12", n) + } + + s := ctx.Response.Body() + if string(s) != "fooпривет" { + t.Fatalf("unexpected response body %q. Expecting %q", s, "fooпривет") + } +} + +func TestServeConnNonHTTP11KeepAlive(t *testing.T) { + t.Parallel() + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.0\r\nConnection: keep-alive\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /bar HTTP/1.0\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /must/be/ignored HTTP/1.0\r\nHost: google.com\r\n\r\n") + + requestsServed := 0 + + ch := make(chan struct{}) + go func() { + err := ServeConn(rw, func(ctx *RequestCtx) { + requestsServed++ + ctx.SuccessString("aaa/bbb", "foobar") + }) + if err != nil { + t.Errorf("unexpected error in ServeConn: %s", err) + } + close(ch) + }() + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + br := bufio.NewReader(&rw.w) + + var resp Response + + // verify the first response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when parsing response: %s", err) + } + if string(resp.Header.Peek(HeaderConnection)) != "keep-alive" { + t.Fatalf("unexpected Connection header %q. Expecting %q", resp.Header.Peek(HeaderConnection), "keep-alive") + } + if resp.Header.ConnectionClose() { + t.Fatal("unexpected Connection: close") + } + + // verify the second response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when parsing response: %s", err) + } + if string(resp.Header.Peek(HeaderConnection)) != "close" { + t.Fatalf("unexpected Connection header %q. 
Expecting %q", resp.Header.Peek(HeaderConnection), "close") + } + if !resp.Header.ConnectionClose() { + t.Fatal("expecting Connection: close") + } + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after responses %q", data) + } + + if requestsServed != 2 { + t.Fatalf("unexpected number of requests served: %d. Expecting 2", requestsServed) + } +} + +func TestRequestCtxSetBodyStreamWriter(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + ctx.Init(&req, nil, defaultLogger) + + if ctx.IsBodyStream() { + t.Fatal("IsBodyStream must return false") + } + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + fmt.Fprintf(w, "body writer line 1\n") + if err := w.Flush(); err != nil { + t.Errorf("unexpected error: %s", err) + } + fmt.Fprintf(w, "body writer line 2\n") + }) + if !ctx.IsBodyStream() { + t.Fatal("IsBodyStream must return true") + } + + s := ctx.Response.String() + + br := bufio.NewReader(bytes.NewBufferString(s)) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("Error when reading response: %s", err) + } + + body := string(resp.Body()) + expectedBody := "body writer line 1\nbody writer line 2\n" + if body != expectedBody { + t.Fatalf("unexpected body: %q. Expecting %q", body, expectedBody) + } +} + +func TestRequestCtxIfModifiedSince(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + ctx.Init(&req, nil, defaultLogger) + + lastModified := time.Now().Add(-time.Hour) + + if !ctx.IfModifiedSince(lastModified) { + t.Fatal("IfModifiedSince must return true for non-existing If-Modified-Since header") + } + + ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) + + if ctx.IfModifiedSince(lastModified) { + t.Fatal("If-Modified-Since current time must return false") + } + + past := lastModified.Add(-time.Hour) + if ctx.IfModifiedSince(past) { + t.Fatal("If-Modified-Since past time must return false") + } + + future := lastModified.Add(time.Hour) + if !ctx.IfModifiedSince(future) { + t.Fatal("If-Modified-Since future time must return true") + } +} + +func TestRequestCtxSendFileNotModified(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + ctx.Init(&req, nil, defaultLogger) + + filePath := "./server_test.go" + lastModified, err := FileLastModified(filePath) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) + + ctx.SendFile(filePath) + + s := ctx.Response.String() + + var resp Response + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusNotModified { + t.Fatalf("unexpected status code: %d. 
Expecting %d", resp.StatusCode(), StatusNotModified) + } + if len(resp.Body()) > 0 { + t.Fatalf("unexpected non-zero response body: %q", resp.Body()) + } +} + +func TestRequestCtxSendFileModified(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + ctx.Init(&req, nil, defaultLogger) + + filePath := "./server_test.go" + lastModified, err := FileLastModified(filePath) + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + lastModified = lastModified.Add(-time.Hour) + ctx.Request.Header.Set("If-Modified-Since", string(AppendHTTPDate(nil, lastModified))) + + ctx.SendFile(filePath) + + s := ctx.Response.String() + + var resp Response + br := bufio.NewReader(bytes.NewBufferString(s)) + if err := resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + f, err := os.Open(filePath) + if err != nil { + t.Fatalf("cannot open file: %s", err) + } + body, err := ioutil.ReadAll(f) + f.Close() + if err != nil { + t.Fatalf("error when reading file: %s", err) + } + + if !bytes.Equal(resp.Body(), body) { + t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) + } +} + +func TestRequestCtxSendFile(t *testing.T) { + t.Parallel() + + var ctx RequestCtx + var req Request + ctx.Init(&req, nil, defaultLogger) + + filePath := "./server_test.go" + ctx.SendFile(filePath) + + w := &bytes.Buffer{} + bw := bufio.NewWriter(w) + if err := ctx.Response.Write(bw); err != nil { + t.Fatalf("error when writing response: %s", err) + } + if err := bw.Flush(); err != nil { + t.Fatalf("error when flushing response: %s", err) + } + + var resp Response + br := bufio.NewReader(w) + if err := resp.Read(br); err != nil { + t.Fatalf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Fatalf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + f, err := os.Open(filePath) + if err != nil { + t.Fatalf("cannot open file: %s", err) + } + body, err := ioutil.ReadAll(f) + f.Close() + if err != nil { + t.Fatalf("error when reading file: %s", err) + } + + if !bytes.Equal(resp.Body(), body) { + t.Fatalf("unexpected response body: %q. Expecting %q", resp.Body(), body) + } +} + +func TestRequestCtxHijack(t *testing.T) { + t.Parallel() + + hijackStartCh := make(chan struct{}) + hijackStopCh := make(chan struct{}) + s := &Server{ + Handler: func(ctx *RequestCtx) { + if ctx.Hijacked() { + t.Error("connection mustn't be hijacked") + } + ctx.Hijack(func(c net.Conn) { + <-hijackStartCh + + b := make([]byte, 1) + // ping-pong echo via hijacked conn + for { + n, err := c.Read(b) + if n != 1 { + if err == io.EOF { + close(hijackStopCh) + return + } + if err != nil { + t.Errorf("unexpected error: %s", err) + } + t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) + } + if _, err = c.Write(b); err != nil { + t.Errorf("unexpected error when writing data: %s", err) + } + } + }) + if !ctx.Hijacked() { + t.Error("connection must be hijacked") + } + ctx.Success("foo/bar", []byte("hijack it!")) + }, + } + + hijackedString := "foobar baz hijacked!!!" 
+ rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString(hijackedString) + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + br := bufio.NewReader(&rw.w) verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") - close(hijackStartCh) + close(hijackStartCh) + select { + case <-hijackStopCh: + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if string(data) != hijackedString { + t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) + } +} + +func TestRequestCtxHijackNoResponse(t *testing.T) { + t.Parallel() + + hijackDone := make(chan error) + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Hijack(func(c net.Conn) { + _, err := c.Write([]byte("test")) + hijackDone <- err + }) + ctx.HijackSetNoResponse(true) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + select { + case err := <-hijackDone: + if err != nil { + t.Fatalf("Unexpected error from hijack: %s", err) + } + case <-time.After(100 * time.Millisecond): + t.Fatal("timeout") + } + + if got := rw.w.String(); got != "test" { + t.Errorf(`expected "test", got %q`, got) + } +} + +func TestRequestCtxNoHijackNoResponse(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + io.WriteString(ctx, "test") //nolint:errcheck + ctx.HijackSetNoResponse(true) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 0\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + bf := bufio.NewReader( + strings.NewReader(rw.w.String()), + ) + resp := AcquireResponse() + resp.Read(bf) //nolint:errcheck + if got := string(resp.Body()); got != "test" { + t.Errorf(`expected "test", got %q`, got) + } +} + +func TestRequestCtxInit(t *testing.T) { + // This test can't run parallel as it modifies globalConnID. + + var ctx RequestCtx + var logger testLogger + globalConnID = 0x123456 + ctx.Init(&ctx.Request, zeroTCPAddr, &logger) + ip := ctx.RemoteIP() + if !ip.IsUnspecified() { + t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip) + } + ctx.Logger().Printf("foo bar %d", 10) + + expectedLog := "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n" + if logger.out != expectedLog { + t.Fatalf("Unexpected log output: %q. 
Expected %q", logger.out, expectedLog) + } +} + +func TestTimeoutHandlerSuccess(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + h := func(ctx *RequestCtx) { + if string(ctx.Path()) == "/" { + ctx.Success("aaa/bbb", []byte("real response")) + } + } + s := &Server{ + Handler: TimeoutHandler(h, 10*time.Second, "timeout!!!"), + } + serverCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + close(serverCh) + }() + + concurrency := 20 + clientCh := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + conn, err := ln.Dial() + if err != nil { + t.Errorf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + clientCh <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + select { + case <-serverCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestTimeoutHandlerTimeout(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + readyCh := make(chan struct{}) + doneCh := make(chan struct{}) + h := func(ctx *RequestCtx) { + ctx.Success("aaa/bbb", []byte("real response")) + <-readyCh + doneCh <- struct{}{} + } + s := &Server{ + Handler: TimeoutHandler(h, 20*time.Millisecond, "timeout!!!"), + } + serverCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + close(serverCh) + }() + + concurrency := 20 + clientCh := make(chan struct{}, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + conn, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!") + clientCh <- struct{}{} + }() + } + + for i := 0; i < concurrency; i++ { + select { + case <-clientCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + close(readyCh) + for i := 0; i < concurrency; i++ { + select { + case <-doneCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + } + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } + select { - case <-hijackStopCh: + case <-serverCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } +} + +func TestTimeoutHandlerTimeoutReuse(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + h := func(ctx *RequestCtx) { + if string(ctx.Path()) == "/timeout" { + time.Sleep(time.Second) + } + ctx.SetBodyString("ok") + } + s := &Server{ + Handler: TimeoutHandler(h, 500*time.Millisecond, "timeout!!!"), + } + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + if _, err = conn.Write([]byte("GET /timeout HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + verifyResponse(t, br, StatusRequestTimeout, 
string(defaultContentType), "timeout!!!") + + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + verifyResponse(t, br, StatusOK, string(defaultContentType), "ok") + + if err := ln.Close(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestServerGetOnly(t *testing.T) { + t.Parallel() + + h := func(ctx *RequestCtx) { + if !ctx.IsGet() { + t.Errorf("non-get request: %q", ctx.Method()) + } + ctx.Success("foo/bar", []byte("success")) + } + s := &Server{ + Handler: h, + GetOnly: true, + } + + rw := &readWriter{} + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345") + + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + + select { + case err := <-ch: + if err == nil { + t.Fatal("expecting error") + } + if err != ErrGetOnly { + t.Fatalf("Unexpected error from serveConn: %s. Expecting %s", err, ErrGetOnly) + } case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + t.Fatal("timeout") + } + + br := bufio.NewReader(&rw.w) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("unexpected error: %s", err) + } + statusCode := resp.StatusCode() + if statusCode != StatusBadRequest { + t.Fatalf("unexpected status code: %d. Expecting %d", statusCode, StatusBadRequest) + } + if !resp.ConnectionClose() { + t.Fatal("missing 'Connection: close' response header") + } +} + +func TestServerTimeoutErrorWithResponse(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + go func() { + ctx.Success("aaa/bbb", []byte("xxxyyy")) + }() + + var resp Response + + resp.SetStatusCode(123) + resp.SetBodyString("foobar. Should be ignored") + ctx.TimeoutErrorWithResponse(&resp) + + resp.SetStatusCode(456) + resp.ResetBody() + fmt.Fprintf(resp.BodyWriter(), "path=%s", ctx.Path()) + resp.Header.SetContentType("foo/bar") + ctx.TimeoutErrorWithResponse(&resp) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 456, "foo/bar", "path=/foo") + verifyResponse(t, br, 456, "foo/bar", "path=/bar") + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") + } +} + +func TestServerTimeoutErrorWithCode(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + go func() { + ctx.Success("aaa/bbb", []byte("xxxyyy")) + }() + ctx.TimeoutErrorWithCode("should be ignored", 234) + ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest) + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx") + verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx") + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + } +} + +func TestServerTimeoutError(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + go func() { + ctx.Success("aaa/bbb", []byte("xxxyyy")) + }() + ctx.TimeoutError("should be ignored") + ctx.TimeoutError("stolen ctx") + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx") + verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx") + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + } +} + +func TestServerMaxRequestsPerConn(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + MaxRequestsPerConn: 1, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + br := bufio.NewReader(&rw.w) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when parsing response: %s", err) + } + if !resp.ConnectionClose() { + t.Fatal("Response must have 'connection: close' header") + } + verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") + } +} + +func TestServerConnectionClose(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.SetConnectionClose() + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) + } + + br := bufio.NewReader(&rw.w) + var resp Response + + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when parsing response: %s", err) + } + if !resp.ConnectionClose() { + t.Fatal("expecting Connection: close header") + } + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Fatalf("Unexpected error when reading remaining data: %s", err) + } + if len(data) != 0 { + t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + } +} + +func TestServerRequestNumAndTime(t *testing.T) { + t.Parallel() + + n := uint64(0) + var connT time.Time + s := &Server{ + Handler: func(ctx *RequestCtx) { + n++ + if ctx.ConnRequestNum() != n { + t.Errorf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n) + } + if connT.IsZero() { + connT = ctx.ConnTime() + } + if ctx.ConnTime() != connT { + t.Errorf("unexpected serve conn time: %s. Expecting %s", ctx.ConnTime(), connT) + } + }, + } + + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - if string(data) != hijackedString { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) + + if n != 3 { + t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3) } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, string(defaultContentType), "") } -func TestRequestCtxInit(t *testing.T) { - var ctx RequestCtx - var logger customLogger - globalConnID = 0x123456 - ctx.Init(&ctx.Request, zeroTCPAddr, &logger) - ip := ctx.RemoteIP() - if !ip.IsUnspecified() { - t.Fatalf("unexpected ip for bare RequestCtx: %q. Expected 0.0.0.0", ip) +func TestServerEmptyResponse(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + // do nothing :) + }, } - ctx.Logger().Printf("foo bar %d", 10) - expectedLog := "#0012345700000000 - 0.0.0.0:0<->0.0.0.0:0 - GET http:/// - foo bar 10\n" - if logger.out != expectedLog { - t.Fatalf("Unexpected log output: %q. Expected %q", logger.out, expectedLog) + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, string(defaultContentType), "") } -func TestTimeoutHandlerSuccess(t *testing.T) { - ln := fasthttputil.NewInmemoryListener() - h := func(ctx *RequestCtx) { - if string(ctx.Path()) == "/" { - ctx.Success("aaa/bbb", []byte("real response")) - } - } +func TestServerLogger(t *testing.T) { + // This test can't run parallel as it modifies globalConnID. 
+ + cl := &testLogger{} s := &Server{ - Handler: TimeoutHandler(h, 10*time.Second, "timeout!!!"), + Handler: func(ctx *RequestCtx) { + logger := ctx.Logger() + h := &ctx.Request.Header + logger.Printf("begin") + ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s", + h.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr()))) + logger.Printf("end") + }, + Logger: cl, } - serverCh := make(chan struct{}) - go func() { - if err := s.Serve(ln); err != nil { - t.Fatalf("unexepcted error: %s", err) - } - close(serverCh) - }() - concurrency := 20 - clientCh := make(chan struct{}, concurrency) - for i := 0; i < concurrency; i++ { - go func() { - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexepcted error: %s", err) - } - if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) - } - br := bufio.NewReader(conn) - verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") - clientCh <- struct{}{} - }() - } + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde") - for i := 0; i < concurrency; i++ { - select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } + rwx := &readWriterRemoteAddr{ + rw: rw, + addr: &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + }, } - if err := ln.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) + globalConnID = 0 + + if err := s.ServeConn(rwx); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - select { - case <-serverCh: - case <-time.After(time.Second): - t.Fatalf("timeout") + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765") + verifyResponse(t, br, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765") + + expectedLogOut := `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin +#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end +#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin +#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end +` + if cl.out != expectedLogOut { + t.Fatalf("Unexpected logger output: %q. 
Expected %q", cl.out, expectedLogOut) } } -func TestTimeoutHandlerTimeout(t *testing.T) { - ln := fasthttputil.NewInmemoryListener() - readyCh := make(chan struct{}) - doneCh := make(chan struct{}) - h := func(ctx *RequestCtx) { - ctx.Success("aaa/bbb", []byte("real response")) - <-readyCh - doneCh <- struct{}{} - } - s := &Server{ - Handler: TimeoutHandler(h, 20*time.Millisecond, "timeout!!!"), - } - serverCh := make(chan struct{}) - go func() { - if err := s.Serve(ln); err != nil { - t.Fatalf("unexepcted error: %s", err) - } - close(serverCh) - }() +func TestServerRemoteAddr(t *testing.T) { + t.Parallel() - concurrency := 20 - clientCh := make(chan struct{}, concurrency) - for i := 0; i < concurrency; i++ { - go func() { - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexepcted error: %s", err) - } - if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { - t.Fatalf("unexpected error: %s", err) - } - br := bufio.NewReader(conn) - verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "timeout!!!") - clientCh <- struct{}{} - }() + s := &Server{ + Handler: func(ctx *RequestCtx) { + h := &ctx.Request.Header + ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", + h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP()))) + }, } - for i := 0; i < concurrency; i++ { - select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } - } + rw := &readWriter{} + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - close(readyCh) - for i := 0; i < concurrency; i++ { - select { - case <-doneCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } + rwx := &readWriterRemoteAddr{ + rw: rw, + addr: &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + }, } - if err := ln.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) + if err := s.ServeConn(rwx); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } - select { - case <-serverCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4") } -func TestServerGetOnly(t *testing.T) { - h := func(ctx *RequestCtx) { - if !ctx.IsGet() { - t.Fatalf("non-get request: %q", ctx.Method()) +func TestServerCustomRemoteAddr(t *testing.T) { + t.Parallel() + + customRemoteAddrHandler := func(h RequestHandler) RequestHandler { + return func(ctx *RequestCtx) { + ctx.SetRemoteAddr(&net.TCPAddr{ + IP: []byte{1, 2, 3, 5}, + Port: 0, + }) + h(ctx) } - ctx.Success("foo/bar", []byte("success")) } + s := &Server{ - Handler: h, - GetOnly: true, + Handler: customRemoteAddrHandler(func(ctx *RequestCtx) { + h := &ctx.Request.Header + ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", + h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP()))) + }), } rw := &readWriter{} - rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\nContent-Type: aaa\r\n\r\n12345") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - select { - case err := <-ch: - if err == nil { - t.Fatalf("expecting error") - } - if err != errGetOnly { - t.Fatalf("Unexpected error from serveConn: %s. 
Expecting %s", err, errGetOnly) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + rwx := &readWriterRemoteAddr{ + rw: rw, + addr: &net.TCPAddr{ + IP: []byte{1, 2, 3, 4}, + Port: 8765, + }, } - resp := rw.w.Bytes() - if len(resp) > 0 { - t.Fatalf("unexpected response %q. Expecting zero", resp) + if err := s.ServeConn(rwx); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.5:0, remoteIP=1.2.3.5") } -func TestServerTimeoutErrorWithResponse(t *testing.T) { - s := &Server{ - Handler: func(ctx *RequestCtx) { - go func() { - ctx.Success("aaa/bbb", []byte("xxxyyy")) - }() +type readWriterRemoteAddr struct { + net.Conn + rw io.ReadWriteCloser + addr net.Addr +} - var resp Response +func (rw *readWriterRemoteAddr) Close() error { + return rw.rw.Close() +} - resp.SetStatusCode(123) - resp.SetBodyString("foobar. Should be ignored") - ctx.TimeoutErrorWithResponse(&resp) +func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) { + return rw.rw.Read(b) +} - resp.SetStatusCode(456) - resp.ResetBody() - fmt.Fprintf(resp.BodyWriter(), "path=%s", ctx.Path()) - resp.Header.SetContentType("foo/bar") - ctx.TimeoutErrorWithResponse(&resp) +func (rw *readWriterRemoteAddr) Write(b []byte) (int, error) { + return rw.rw.Write(b) +} + +func (rw *readWriterRemoteAddr) RemoteAddr() net.Addr { + return rw.addr +} + +func (rw *readWriterRemoteAddr) LocalAddr() net.Addr { + return rw.addr +} + +func TestServerConnError(t *testing.T) { + t.Parallel() + + s := &Server{ + Handler: func(ctx *RequestCtx) { + ctx.Error("foobar", 423) }, } rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 456, "foo/bar", "path=/foo") - - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + var resp Response + if err := resp.Read(br); err != nil { + t.Fatalf("Unexpected error when reading response: %s", err) } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + if resp.Header.StatusCode() != 423 { + t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423) + } + if resp.Header.ContentLength() != 6 { + t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), 6) + } + if !bytes.Equal(resp.Header.Peek(HeaderContentType), defaultContentType) { + t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Peek(HeaderContentType), defaultContentType) + } + if !bytes.Equal(resp.Body(), []byte("foobar")) { + t.Fatalf("Unexpected body %q. 
Expected %q", resp.Body(), "foobar") } } -func TestServerTimeoutErrorWithCode(t *testing.T) { +func TestServeConnSingleRequest(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { - go func() { - ctx.Success("aaa/bbb", []byte("xxxyyy")) - }() - ctx.TimeoutErrorWithCode("should be ignored", 234) - ctx.TimeoutErrorWithCode("stolen ctx", StatusBadRequest) + h := &ctx.Request.Header + ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek(HeaderHost)))) }, } rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } br := bufio.NewReader(&rw.w) - verifyResponse(t, br, StatusBadRequest, string(defaultContentType), "stolen ctx") - - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) - } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") - } + verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") } -func TestServerTimeoutError(t *testing.T) { +func TestServeConnMultiRequests(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { - go func() { - ctx.Success("aaa/bbb", []byte("xxxyyy")) - }() - ctx.TimeoutError("should be ignored") - ctx.TimeoutError("stolen ctx") + h := &ctx.Request.Header + ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek(HeaderHost)))) }, } rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n") - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + if err := s.ServeConn(rw); err != nil { + t.Fatalf("Unexpected error from serveConn: %s", err) } br := bufio.NewReader(&rw.w) - verifyResponse(t, br, StatusRequestTimeout, string(defaultContentType), "stolen ctx") - - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) - } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") - } + verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") + verifyResponse(t, br, 200, "aaa", "requestURI=/abc, host=foobar.com") } -func TestServerMaxKeepaliveDuration(t *testing.T) { +func TestShutdown(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { - time.Sleep(20 * time.Millisecond) + time.Sleep(time.Millisecond * 500) + ctx.Success("aaa/bbb", []byte("real response")) }, - MaxKeepaliveDuration: 10 * time.Millisecond, } - - rw := &readWriter{} - rw.r.WriteString("GET /aaa HTTP/1.1\r\nHost: aa.com\r\n\r\n") - rw.r.WriteString("GET /bbbb HTTP/1.1\r\nHost: bbb.com\r\n\r\n") - - ch := make(chan error) + serveCh := make(chan struct{}) go func() { - ch <- s.ServeConn(rw) + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + _, err := ln.Dial() + if err == nil { + t.Error("server is still listening") + } + serveCh <- struct{}{} }() - - select { - case err := <-ch: + clientCh := make(chan struct{}) + go func() { + conn, err := ln.Dial() if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + t.Errorf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + verifyResponseHeaderConnection(t, &resp.Header, "") + clientCh <- struct{}{} + }() + time.Sleep(time.Millisecond * 100) + shutdownCh := make(chan struct{}) + go func() { + if err := s.Shutdown(); err != nil { + t.Errorf("unexepcted error: %s", err) + } + shutdownCh <- struct{}{} + }() + done := 0 + for { + select { + case <-time.After(time.Second * 2): + t.Fatal("shutdown took too long") + case <-serveCh: + done++ + case <-clientCh: + done++ + case <-shutdownCh: + done++ + } + if done == 3 { + return } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") } +} - br := bufio.NewReader(&rw.w) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) - } - if !resp.ConnectionClose() { - t.Fatalf("Response must have 'connection: close' header") - } - verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) +func TestCloseOnShutdown(t *testing.T) { + t.Parallel() - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + time.Sleep(time.Millisecond * 500) + ctx.Success("aaa/bbb", []byte("real response")) + }, + CloseOnShutdown: true, } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") + serveCh := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + _, err := ln.Dial() + if err == nil { + t.Error("server is still listening") + } + serveCh <- struct{}{} + }() + clientCh := make(chan struct{}) + go func() { + conn, err := ln.Dial() + if err != nil { + t.Errorf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Errorf("unexpected error: %s", err) + } + br := bufio.NewReader(conn) + resp := verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + verifyResponseHeaderConnection(t, &resp.Header, "close") + clientCh <- struct{}{} + }() + time.Sleep(time.Millisecond * 100) + shutdownCh := make(chan struct{}) + go func() { + if err := s.Shutdown(); err != nil { + t.Errorf("unexepcted error: %s", err) + } + shutdownCh <- struct{}{} + }() + done := 0 + for { + select { + case <-time.After(time.Second): + t.Fatal("shutdown took too long") + case <-serveCh: + done++ + case <-clientCh: + done++ + case <-shutdownCh: + done++ + } + if done == 3 { + return + } } } -func TestServerMaxRequestsPerConn(t *testing.T) { +func TestShutdownReuse(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() s := &Server{ - Handler: func(ctx *RequestCtx) {}, - MaxRequestsPerConn: 1, + Handler: func(ctx *RequestCtx) { + ctx.Success("aaa/bbb", []byte("real response")) + }, + ReadTimeout: time.Millisecond * 100, + Logger: &testLogger{}, // Ignore log output. } - - rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: aaa.com\r\n\r\n") - - ch := make(chan error) go func() { - ch <- s.ServeConn(rw) - }() - - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + }() + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) } - - br := bufio.NewReader(&rw.w) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) } - if !resp.ConnectionClose() { - t.Fatalf("Response must have 'connection: close' header") + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + if err := s.Shutdown(); err != nil { + t.Fatalf("unexepcted error: %s", err) } - verifyResponseHeader(t, &resp.Header, 200, 0, string(defaultContentType)) - - data, err := ioutil.ReadAll(br) + ln = fasthttputil.NewInmemoryListener() + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + conn, err = ln.Dial() if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + t.Fatalf("unexepcted error: %s", err) } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. 
Expecting %q", data, "") + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br = bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") + if err := s.Shutdown(); err != nil { + t.Fatalf("unexepcted error: %s", err) } } -func TestServerConnectionClose(t *testing.T) { +func TestShutdownDone(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.SetConnectionClose() + <-ctx.Done() + ctx.Success("aaa/bbb", []byte("real response")) }, } - - rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /must/be/ignored HTTP/1.1\r\nHost: aaa.com\r\n\r\n") - - ch := make(chan error) go func() { - ch <- s.ServeConn(rw) - }() - - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + }() + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + go func() { + // Shutdown won't return if the connection doesn't close, + // which doesn't happen until we read the response. + if err := s.Shutdown(); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + // We can only reach this point and get a valid response + // if reading from ctx.Done() returned. + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") +} - br := bufio.NewReader(&rw.w) - var resp Response +func TestShutdownErr(t *testing.T) { + t.Parallel() - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when parsing response: %s", err) - } - if !resp.ConnectionClose() { - t.Fatalf("expecting Connection: close header") + ln := fasthttputil.NewInmemoryListener() + s := &Server{ + Handler: func(ctx *RequestCtx) { + // This will panic, but I was not able to intercept with recover() + c, cancel := context.WithCancel(ctx) + defer cancel() + <-c.Done() + ctx.Success("aaa/bbb", []byte("real response")) + }, } - data, err := ioutil.ReadAll(br) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + conn, err := ln.Dial() if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) - } - if len(data) != 0 { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, "") + t.Fatalf("unexepcted error: %s", err) } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + go func() { + // Shutdown won't return if the connection doesn't close, + // which doesn't happen until we read the response. + if err := s.Shutdown(); err != nil { + t.Errorf("unexepcted error: %s", err) + } + }() + // We can only reach this point and get a valid response + // if reading from ctx.Done() returned. 
+ br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") } -func TestServerRequestNumAndTime(t *testing.T) { - n := uint64(0) - var connT time.Time +func TestMultipleServe(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { - n++ - if ctx.ConnRequestNum() != n { - t.Fatalf("unexpected request number: %d. Expecting %d", ctx.ConnRequestNum(), n) - } - if connT.IsZero() { - connT = ctx.ConnTime() - } - if ctx.ConnTime() != connT { - t.Fatalf("unexpected serve conn time: %s. Expecting %s", ctx.ConnTime(), connT) - } + ctx.Success("aaa/bbb", []byte("real response")) }, } - rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /bar HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("GET /baz HTTP/1.1\r\nHost: google.com\r\n\r\n") + ln1 := fasthttputil.NewInmemoryListener() + ln2 := fasthttputil.NewInmemoryListener() - ch := make(chan error) go func() { - ch <- s.ServeConn(rw) + if err := s.Serve(ln1); err != nil { + t.Errorf("unexepcted error: %s", err) + } }() - - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + go func() { + if err := s.Serve(ln2); err != nil { + t.Errorf("unexepcted error: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") - } + }() - if n != 3 { - t.Fatalf("unexpected number of requests served: %d. Expecting %d", n, 3) + conn, err := ln1.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) } + br := bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, string(defaultContentType), "") + conn, err = ln2.Dial() + if err != nil { + t.Fatalf("unexepcted error: %s", err) + } + if _, err = conn.Write([]byte("GET / HTTP/1.1\r\nHost: google.com\r\n\r\n")); err != nil { + t.Fatalf("unexpected error: %s", err) + } + br = bufio.NewReader(conn) + verifyResponse(t, br, StatusOK, "aaa/bbb", "real response") } -func TestServerEmptyResponse(t *testing.T) { +func TestMaxBodySizePerRequest(t *testing.T) { + t.Parallel() + s := &Server{ Handler: func(ctx *RequestCtx) { // do nothing :) }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + MaxRequestBodySize: 5 << 10, + } + }, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, + MaxRequestBodySize: 1 << 20, } rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - - ch := make(chan error) - go func() { - ch <- s.ServeConn(rw) - }() + rw.r.WriteString(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", (5<<10)+1, strings.Repeat("a", (5<<10)+1))) - select { - case err := <-ch: - if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + if err := s.ServeConn(rw); err != ErrBodyTooLarge { + t.Fatalf("Unexpected error from serveConn: %s", err) } - - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, string(defaultContentType), "") } -type customLogger struct { - lock sync.Mutex - out string -} +func TestStreamRequestBody(t *testing.T) { + t.Parallel() -func (cl *customLogger) Printf(format string, args ...interface{}) { - cl.lock.Lock() - cl.out += fmt.Sprintf(format, args...)[6:] 
+ "\n" - cl.lock.Unlock() -} + part1 := strings.Repeat("1", 1<<15) + part2 := strings.Repeat("2", 1<<16) + contentLength := len(part1) + len(part2) + next := make(chan struct{}) -func TestServerLogger(t *testing.T) { - cl := &customLogger{} s := &Server{ Handler: func(ctx *RequestCtx) { - logger := ctx.Logger() - h := &ctx.Request.Header - logger.Printf("begin") - ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, body=%q, remoteAddr=%s", - h.RequestURI(), ctx.Request.Body(), ctx.RemoteAddr()))) - logger.Printf("end") + checkReader(t, ctx.RequestBodyStream(), part1) + close(next) + checkReader(t, ctx.RequestBodyStream(), part2) }, - Logger: cl, + StreamRequestBody: true, } - rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: 5\r\nContent-Type: aa\r\n\r\nabcde") - - rwx := &readWriterRemoteAddr{ - rw: rw, - addr: &net.TCPAddr{ - IP: []byte{1, 2, 3, 4}, - Port: 8765, - }, + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + //write headers and part1 body + if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil { + t.Fatal(err) + } + if _, err := cc.Write([]byte(part1)); err != nil { + t.Fatal(err) } - globalConnID = 0 ch := make(chan error) go func() { - ch <- s.ServeConn(rwx) + ch <- s.ServeConn(sc) }() select { + case <-next: + case <-time.After(500 * time.Millisecond): + t.Fatal("part1 timeout") + } + + if _, err := cc.Write([]byte(part2)); err != nil { + t.Fatal(err) + } + if err := sc.Close(); err != nil { + t.Fatal(err) + } + + select { case err := <-ch: - if err != nil { + if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match. t.Fatalf("Unexpected error from serveConn: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + case <-time.After(500 * time.Millisecond): + t.Fatal("part2 timeout") } +} - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, body=\"\", remoteAddr=1.2.3.4:8765") - verifyResponse(t, br, 200, "text/html", "requestURI=/foo2, body=\"abcde\", remoteAddr=1.2.3.4:8765") +func TestStreamRequestBodyExceedMaxSize(t *testing.T) { + t.Parallel() - expectedLogOut := `#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - begin -#0000000100000001 - 1.2.3.4:8765<->1.2.3.4:8765 - GET http://google.com/foo1 - end -#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - begin -#0000000100000002 - 1.2.3.4:8765<->1.2.3.4:8765 - POST http://aaa.com/foo2 - end -` - if cl.out != expectedLogOut { - t.Fatalf("Unexpected logger output: %q. 
Expected %q", cl.out, expectedLogOut) - } -} + part1 := strings.Repeat("1", 1<<18) + part2 := strings.Repeat("2", 1<<20-1<<18) + contentLength := len(part1) + len(part2) + next := make(chan struct{}) -func TestServerRemoteAddr(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { - h := &ctx.Request.Header - ctx.Success("text/html", []byte(fmt.Sprintf("requestURI=%s, remoteAddr=%s, remoteIP=%s", - h.RequestURI(), ctx.RemoteAddr(), ctx.RemoteIP()))) + checkReader(t, ctx.RequestBodyStream(), part1) + close(next) + checkReader(t, ctx.RequestBodyStream(), part2) }, + DisableKeepalive: true, + StreamRequestBody: true, + MaxRequestBodySize: 1, } - rw := &readWriter{} - rw.r.WriteString("GET /foo1 HTTP/1.1\r\nHost: google.com\r\n\r\n") - - rwx := &readWriterRemoteAddr{ - rw: rw, - addr: &net.TCPAddr{ - IP: []byte{1, 2, 3, 4}, - Port: 8765, - }, + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + //write headers and part1 body + if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil { + t.Error(err) } ch := make(chan error) go func() { - ch <- s.ServeConn(rwx) + ch <- s.ServeConn(sc) }() select { + case <-next: + case <-time.After(500 * time.Millisecond): + t.Fatal("part1 timeout") + } + + if _, err := cc.Write([]byte(part2)); err != nil { + t.Error(err) + } + + select { case err := <-ch: if err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) + t.Error(err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + case <-time.After(500 * time.Millisecond): + t.Fatal("part2 timeout") } - - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, "text/html", "requestURI=/foo1, remoteAddr=1.2.3.4:8765, remoteIP=1.2.3.4") -} - -type readWriterRemoteAddr struct { - net.Conn - rw io.ReadWriteCloser - addr net.Addr -} - -func (rw *readWriterRemoteAddr) Close() error { - return rw.rw.Close() -} - -func (rw *readWriterRemoteAddr) Read(b []byte) (int, error) { - return rw.rw.Read(b) -} - -func (rw *readWriterRemoteAddr) Write(b []byte) (int, error) { - return rw.rw.Write(b) -} - -func (rw *readWriterRemoteAddr) RemoteAddr() net.Addr { - return rw.addr } -func (rw *readWriterRemoteAddr) LocalAddr() net.Addr { - return rw.addr -} +func TestStreamBodyReqestContentLength(t *testing.T) { + t.Parallel() + content := strings.Repeat("1", 1<<15) // 32K + contentLength := len(content) -func TestServerConnError(t *testing.T) { s := &Server{ Handler: func(ctx *RequestCtx) { - ctx.Error("foobar", 423) + realContentLength := ctx.Request.Header.ContentLength() + if realContentLength != contentLength { + t.Fatal("incorrect content length") + } }, + MaxRequestBodySize: 1 * 1024 * 1024, // 1M + StreamRequestBody: true, } - rw := &readWriter{} - rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, content))); err != nil { + t.Fatal(err) + } ch := make(chan error) go func() { - ch <- s.ServeConn(rw) + ch <- s.ServeConn(sc) }() + if err := sc.Close(); err != nil { + t.Fatal(err) + } + select { case err := <-ch: - if err != nil { + if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match. 
t.Fatalf("Unexpected error from serveConn: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + case <-time.After(time.Second): + t.Fatal("test timeout") } +} - br := bufio.NewReader(&rw.w) - var resp Response - if err := resp.Read(br); err != nil { - t.Fatalf("Unexpected error when reading response: %s", err) - } - if resp.Header.StatusCode() != 423 { - t.Fatalf("Unexpected status code %d. Expected %d", resp.Header.StatusCode(), 423) - } - if resp.Header.ContentLength() != 6 { - t.Fatalf("Unexpected Content-Length %d. Expected %d", resp.Header.ContentLength(), 6) - } - if !bytes.Equal(resp.Header.Peek("Content-Type"), defaultContentType) { - t.Fatalf("Unexpected Content-Type %q. Expected %q", resp.Header.Peek("Content-Type"), defaultContentType) +func checkReader(t *testing.T, r io.Reader, expected string) { + b := make([]byte, len(expected)) + if _, err := io.ReadFull(r, b); err != nil { + t.Fatalf("Unexpected error from reader: %s", err) } - if !bytes.Equal(resp.Body(), []byte("foobar")) { - t.Fatalf("Unexpected body %q. Expected %q", resp.Body(), "foobar") + if string(b) != expected { + t.Fatal("incorrect request body") } } -func TestServeConnSingleRequest(t *testing.T) { +func TestMaxReadTimeoutPerRequest(t *testing.T) { + t.Parallel() + + headers := []byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", 5*1024)) s := &Server{ Handler: func(ctx *RequestCtx) { - h := &ctx.Request.Header - ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek("Host")))) + t.Error("shouldn't reach handler") + }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + ReadTimeout: time.Millisecond, + } }, + ReadBufferSize: len(headers), + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, } - rw := &readWriter{} - rw.r.WriteString("GET /foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\n") - + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + go func() { + //write headers + _, err := cc.Write(headers) + if err != nil { + t.Error(err) + } + //write body + for i := 0; i < 5*1024; i++ { + time.Sleep(time.Millisecond) + cc.Write([]byte{'a'}) //nolint:errcheck + } + }() ch := make(chan error) go func() { - ch <- s.ServeConn(rw) + ch <- s.ServeConn(sc) }() select { case err := <-ch: - if err != nil { + if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") { t.Fatalf("Unexpected error from serveConn: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + case <-time.After(time.Second): + t.Fatal("test timeout") } - - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") } -func TestServeConnMultiRequests(t *testing.T) { +func TestMaxWriteTimeoutPerRequest(t *testing.T) { + t.Parallel() + + headers := []byte("GET /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Type: aa\r\n\r\n") s := &Server{ Handler: func(ctx *RequestCtx) { - h := &ctx.Request.Header - ctx.Success("aaa", []byte(fmt.Sprintf("requestURI=%s, host=%s", h.RequestURI(), h.Peek("Host")))) + ctx.SetBodyStreamWriter(func(w *bufio.Writer) { + var buf [192]byte + for { + w.Write(buf[:]) //nolint:errcheck + } + }) + }, + HeaderReceived: func(header *RequestHeader) RequestConfig { + return RequestConfig{ + WriteTimeout: time.Millisecond, + } }, + ReadBufferSize: 192, + ReadTimeout: time.Second * 5, + WriteTimeout: time.Second * 5, } - rw := &readWriter{} - rw.r.WriteString("GET 
/foo/bar?baz HTTP/1.1\r\nHost: google.com\r\n\r\nGET /abc HTTP/1.1\r\nHost: foobar.com\r\n\r\n") + pipe := fasthttputil.NewPipeConns() + cc, sc := pipe.Conn1(), pipe.Conn2() + + var resp Response + go func() { + //write headers + _, err := cc.Write(headers) + if err != nil { + t.Error(err) + } + br := bufio.NewReaderSize(cc, 192) + err = resp.Header.Read(br) + if err != nil { + t.Error(err) + } + var chunk [192]byte + for { + time.Sleep(time.Millisecond) + br.Read(chunk[:]) //nolint:errcheck + } + }() ch := make(chan error) go func() { - ch <- s.ServeConn(rw) + ch <- s.ServeConn(sc) }() select { case err := <-ch: - if err != nil { + if err == nil || err != nil && !strings.EqualFold(err.Error(), "timeout") { t.Fatalf("Unexpected error from serveConn: %s", err) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("timeout") + case <-time.After(time.Second): + t.Fatal("test timeout") } +} - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, 200, "aaa", "requestURI=/foo/bar?baz, host=google.com") - verifyResponse(t, br, 200, "aaa", "requestURI=/abc, host=foobar.com") +func TestIncompleteBodyReturnsUnexpectedEOF(t *testing.T) { + t.Parallel() + + rw := &readWriter{} + rw.r.WriteString("POST /foo HTTP/1.1\r\nHost: google.com\r\nContent-Length: 5\r\n\r\n123") + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + } + ch := make(chan error) + go func() { + ch <- s.ServeConn(rw) + }() + if err := <-ch; err == nil || err.Error() != "unexpected EOF" { + t.Fatal(err) + } } -func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) { +func verifyResponse(t *testing.T, r *bufio.Reader, expectedStatusCode int, expectedContentType, expectedBody string) *Response { var resp Response if err := resp.Read(r); err != nil { t.Fatalf("Unexpected error when parsing response: %s", err) @@ -2314,6 +3706,7 @@ t.Fatalf("Unexpected body %q. 
Expected %q", resp.Body(), []byte(expectedBody)) } verifyResponseHeader(t, &resp.Header, expectedStatusCode, len(resp.Body()), expectedContentType) + return &resp } type readWriter struct { @@ -2342,6 +3735,10 @@ return zeroTCPAddr } +func (rw *readWriter) SetDeadline(t time.Time) error { + return nil +} + func (rw *readWriter) SetReadDeadline(t time.Time) error { return nil } @@ -2349,3 +3746,14 @@ func (rw *readWriter) SetWriteDeadline(t time.Time) error { return nil } + +type testLogger struct { + lock sync.Mutex + out string +} + +func (cl *testLogger) Printf(format string, args ...interface{}) { + cl.lock.Lock() + cl.out += fmt.Sprintf(format, args...)[6:] + "\n" + cl.lock.Unlock() +} diff -Nru golang-github-valyala-fasthttp-20160617/server_timing_test.go golang-github-valyala-fasthttp-1.31.0/server_timing_test.go --- golang-github-valyala-fasthttp-20160617/server_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/server_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -16,6 +16,16 @@ var defaultClientsCount = runtime.NumCPU() +func BenchmarkRequestCtxRedirect(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + var ctx RequestCtx + for pb.Next() { + ctx.Request.SetRequestURI("http://aaa.com/fff/ss.html?sdf") + ctx.Redirect("/foo/bar?baz=111", StatusFound) + } + }) +} + func BenchmarkServerGet1ReqPerConn(b *testing.B) { benchmarkServerGet(b, defaultClientsCount, 1) } @@ -337,15 +347,15 @@ ch := make(chan struct{}, b.N) s := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Method != "GET" { + if req.Method != MethodGet { b.Fatalf("Unexpected request method: %s", req.Method) } h := w.Header() h.Set("Content-Type", "text/plain") if requestsPerConn == 1 { - h.Set("Connection", "close") + h.Set(HeaderConnection, "close") } - w.Write(fakeResponse) + w.Write(fakeResponse) //nolint:errcheck registerServedRequest(b, ch) }), } @@ -380,7 +390,7 @@ ch := make(chan struct{}, b.N) s := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { + if req.Method != MethodPost { b.Fatalf("Unexpected request method: %s", req.Method) } body, err := ioutil.ReadAll(req.Body) @@ -394,9 +404,9 @@ h := w.Header() h.Set("Content-Type", "text/plain") if requestsPerConn == 1 { - h.Set("Connection", "close") + h.Set(HeaderConnection, "close") } - w.Write(body) + w.Write(body) //nolint:errcheck registerServedRequest(b, ch) }), } @@ -437,7 +447,7 @@ ln := newFakeListener(b.N, clientsCount, requestsPerConn, request) ch := make(chan struct{}) go func() { - s.Serve(ln) + s.Serve(ln) //nolint:errcheck ch <- struct{}{} }() diff -Nru golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.key golang-github-valyala-fasthttp-1.31.0/ssl-cert-snakeoil.key --- golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.key 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/ssl-cert-snakeoil.key 1970-01-01 00:00:00.000000000 +0000 @@ -1,28 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQD4IQusAs8PJdnG -3mURt/AXtgC+ceqLOatJ49JJE1VPTkMAy+oE1f1XvkMrYsHqmDf6GWVzgVXryL4U -wq2/nJSm56ddhN55nI8oSN3dtywUB8/ShelEN73nlN77PeD9tl6NksPwWaKrqxq0 -FlabRPZSQCfmgZbhDV8Sa8mfCkFU0G0lit6kLGceCKMvmW+9Bz7ebsYmVdmVMxmf -IJStFD44lWFTdUc65WISKEdW2ELcUefb0zOLw+0PCbXFGJH5x5ktksW8+BBk2Hkg -GeQRL/qPCccthbScO0VgNj3zJ3ZZL0ObSDAbvNDG85joeNjDNq5DT/BAZ0bOSbEF -sh+f9BAzAgMBAAECggEBAJWv2cq7Jw6MVwSRxYca38xuD6TUNBopgBvjREixURW2 
-sNUaLuMb9Omp7fuOaE2N5rcJ+xnjPGIxh/oeN5MQctz9gwn3zf6vY+15h97pUb4D -uGvYPRDaT8YVGS+X9NMZ4ZCmqW2lpWzKnCFoGHcy8yZLbcaxBsRdvKzwOYGoPiFb -K2QuhXZ/1UPmqK9i2DFKtj40X6vBszTNboFxOVpXrPu0FJwLVSDf2hSZ4fMM0DH3 -YqwKcYf5te+hxGKgrqRA3tn0NCWii0in6QIwXMC+kMw1ebg/tZKqyDLMNptAK8J+ -DVw9m5X1seUHS5ehU/g2jrQrtK5WYn7MrFK4lBzlRwECgYEA/d1TeANYECDWRRDk -B0aaRZs87Rwl/J9PsvbsKvtU/bX+OfSOUjOa9iQBqn0LmU8GqusEET/QVUfocVwV -Bggf/5qDLxz100Rj0ags/yE/kNr0Bb31kkkKHFMnCT06YasR7qKllwrAlPJvQv9x -IzBKq+T/Dx08Wep9bCRSFhzRCnsCgYEA+jdeZXTDr/Vz+D2B3nAw1frqYFfGnEVY -wqmoK3VXMDkGuxsloO2rN+SyiUo3JNiQNPDub/t7175GH5pmKtZOlftePANsUjBj -wZ1D0rI5Bxu/71ibIUYIRVmXsTEQkh/ozoh3jXCZ9+bLgYiYx7789IUZZSokFQ3D -FICUT9KJ36kCgYAGoq9Y1rWJjmIrYfqj2guUQC+CfxbbGIrrwZqAsRsSmpwvhZ3m -tiSZxG0quKQB+NfSxdvQW5ulbwC7Xc3K35F+i9pb8+TVBdeaFkw+yu6vaZmxQLrX -fQM/pEjD7A7HmMIaO7QaU5SfEAsqdCTP56Y8AftMuNXn/8IRfo2KuGwaWwKBgFpU -ILzJoVdlad9E/Rw7LjYhZfkv1uBVXIyxyKcfrkEXZSmozDXDdxsvcZCEfVHM6Ipk -K/+7LuMcqp4AFEAEq8wTOdq6daFaHLkpt/FZK6M4TlruhtpFOPkoNc3e45eM83OT -6mziKINJC1CQ6m65sQHpBtjxlKMRG8rL/D6wx9s5AoGBAMRlqNPMwglT3hvDmsAt -9Lf9pdmhERUlHhD8bj8mDaBj2Aqv7f6VRJaYZqP403pKKQexuqcn80mtjkSAPFkN -Cj7BVt/RXm5uoxDTnfi26RF9F6yNDEJ7UU9+peBr99aazF/fTgW/1GcMkQnum8uV -c257YgaWmjK9uB0Y2r2VxS0G ------END PRIVATE KEY----- diff -Nru golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.pem golang-github-valyala-fasthttp-1.31.0/ssl-cert-snakeoil.pem --- golang-github-valyala-fasthttp-20160617/ssl-cert-snakeoil.pem 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/ssl-cert-snakeoil.pem 1970-01-01 00:00:00.000000000 +0000 @@ -1,17 +0,0 @@ ------BEGIN CERTIFICATE----- -MIICujCCAaKgAwIBAgIJAMbXnKZ/cikUMA0GCSqGSIb3DQEBCwUAMBUxEzARBgNV -BAMTCnVidW50dS5uYW4wHhcNMTUwMjA0MDgwMTM5WhcNMjUwMjAxMDgwMTM5WjAV -MRMwEQYDVQQDEwp1YnVudHUubmFuMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEA+CELrALPDyXZxt5lEbfwF7YAvnHqizmrSePSSRNVT05DAMvqBNX9V75D -K2LB6pg3+hllc4FV68i+FMKtv5yUpuenXYTeeZyPKEjd3bcsFAfP0oXpRDe955Te -+z3g/bZejZLD8Fmiq6satBZWm0T2UkAn5oGW4Q1fEmvJnwpBVNBtJYrepCxnHgij -L5lvvQc+3m7GJlXZlTMZnyCUrRQ+OJVhU3VHOuViEihHVthC3FHn29Mzi8PtDwm1 -xRiR+ceZLZLFvPgQZNh5IBnkES/6jwnHLYW0nDtFYDY98yd2WS9Dm0gwG7zQxvOY -6HjYwzauQ0/wQGdGzkmxBbIfn/QQMwIDAQABow0wCzAJBgNVHRMEAjAAMA0GCSqG -SIb3DQEBCwUAA4IBAQBQjKm/4KN/iTgXbLTL3i7zaxYXFLXsnT1tF+ay4VA8aj98 -L3JwRTciZ3A5iy/W4VSCt3eASwOaPWHKqDBB5RTtL73LoAqsWmO3APOGQAbixcQ2 -45GXi05OKeyiYRi1Nvq7Unv9jUkRDHUYVPZVSAjCpsXzPhFkmZoTRxmx5l0ZF7Li -K91lI5h+eFq0dwZwrmlPambyh1vQUi70VHv8DNToVU29kel7YLbxGbuqETfhrcy6 -X+Mha6RYITkAn5FqsZcKMsc9eYGEF4l3XV+oS7q6xfTxktYJMFTI18J0lQ2Lv/CI -whdMnYGntDQBE/iFCrJEGNsKGc38796GBOb5j+zd ------END CERTIFICATE----- diff -Nru golang-github-valyala-fasthttp-20160617/stackless/doc.go golang-github-valyala-fasthttp-1.31.0/stackless/doc.go --- golang-github-valyala-fasthttp-20160617/stackless/doc.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/doc.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,3 @@ +// Package stackless provides functionality that may save stack space +// for high number of concurrently running goroutines. +package stackless diff -Nru golang-github-valyala-fasthttp-20160617/stackless/func.go golang-github-valyala-fasthttp-1.31.0/stackless/func.go --- golang-github-valyala-fasthttp-20160617/stackless/func.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/func.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,79 @@ +package stackless + +import ( + "runtime" + "sync" +) + +// NewFunc returns stackless wrapper for the function f. 
+// +// Unlike f, the returned stackless wrapper doesn't use stack space +// on the goroutine that calls it. +// The wrapper may save a lot of stack space if the following conditions +// are met: +// +// - f doesn't contain blocking calls on network, I/O or channels; +// - f uses a lot of stack space; +// - the wrapper is called from high number of concurrent goroutines. +// +// The stackless wrapper returns false if the call cannot be processed +// at the moment due to high load. +func NewFunc(f func(ctx interface{})) func(ctx interface{}) bool { + if f == nil { + panic("BUG: f cannot be nil") + } + + funcWorkCh := make(chan *funcWork, runtime.GOMAXPROCS(-1)*2048) + onceInit := func() { + n := runtime.GOMAXPROCS(-1) + for i := 0; i < n; i++ { + go funcWorker(funcWorkCh, f) + } + } + var once sync.Once + + return func(ctx interface{}) bool { + once.Do(onceInit) + fw := getFuncWork() + fw.ctx = ctx + + select { + case funcWorkCh <- fw: + default: + putFuncWork(fw) + return false + } + <-fw.done + putFuncWork(fw) + return true + } +} + +func funcWorker(funcWorkCh <-chan *funcWork, f func(ctx interface{})) { + for fw := range funcWorkCh { + f(fw.ctx) + fw.done <- struct{}{} + } +} + +func getFuncWork() *funcWork { + v := funcWorkPool.Get() + if v == nil { + v = &funcWork{ + done: make(chan struct{}, 1), + } + } + return v.(*funcWork) +} + +func putFuncWork(fw *funcWork) { + fw.ctx = nil + funcWorkPool.Put(fw) +} + +var funcWorkPool sync.Pool + +type funcWork struct { + ctx interface{} + done chan struct{} +} diff -Nru golang-github-valyala-fasthttp-20160617/stackless/func_test.go golang-github-valyala-fasthttp-1.31.0/stackless/func_test.go --- golang-github-valyala-fasthttp-20160617/stackless/func_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/func_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,90 @@ +package stackless + +import ( + "fmt" + "sync/atomic" + "testing" + "time" +) + +func TestNewFuncSimple(t *testing.T) { + t.Parallel() + + var n uint64 + f := NewFunc(func(ctx interface{}) { + atomic.AddUint64(&n, uint64(ctx.(int))) + }) + + iterations := 4 * 1024 + for i := 0; i < iterations; i++ { + if !f(2) { + t.Fatalf("f mustn't return false") + } + } + if n != uint64(2*iterations) { + t.Fatalf("Unexpected n: %d. Expecting %d", n, 2*iterations) + } +} + +func TestNewFuncMulti(t *testing.T) { + t.Parallel() + + var n1, n2 uint64 + f1 := NewFunc(func(ctx interface{}) { + atomic.AddUint64(&n1, uint64(ctx.(int))) + }) + f2 := NewFunc(func(ctx interface{}) { + atomic.AddUint64(&n2, uint64(ctx.(int))) + }) + + iterations := 4 * 1024 + + f1Done := make(chan error, 1) + go func() { + var err error + for i := 0; i < iterations; i++ { + if !f1(3) { + err = fmt.Errorf("f1 mustn't return false") + break + } + } + f1Done <- err + }() + + f2Done := make(chan error, 1) + go func() { + var err error + for i := 0; i < iterations; i++ { + if !f2(5) { + err = fmt.Errorf("f2 mustn't return false") + break + } + } + f2Done <- err + }() + + select { + case err := <-f1Done: + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + select { + case err := <-f2Done: + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + case <-time.After(time.Second): + t.Fatalf("timeout") + } + + if n1 != uint64(3*iterations) { + t.Fatalf("unexpected n1: %d. Expecting %d", n1, 3*iterations) + } + if n2 != uint64(5*iterations) { + t.Fatalf("unexpected n2: %d. 
Expecting %d", n2, 5*iterations) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/stackless/func_timing_test.go golang-github-valyala-fasthttp-1.31.0/stackless/func_timing_test.go --- golang-github-valyala-fasthttp-20160617/stackless/func_timing_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/func_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,40 @@ +package stackless + +import ( + "sync/atomic" + "testing" +) + +func BenchmarkFuncOverhead(b *testing.B) { + var n uint64 + f := NewFunc(func(ctx interface{}) { + atomic.AddUint64(&n, *(ctx.(*uint64))) + }) + b.RunParallel(func(pb *testing.PB) { + x := uint64(1) + for pb.Next() { + if !f(&x) { + b.Fatalf("f mustn't return false") + } + } + }) + if n != uint64(b.N) { + b.Fatalf("unexected n: %d. Expecting %d", n, b.N) + } +} + +func BenchmarkFuncPure(b *testing.B) { + var n uint64 + f := func(x *uint64) { + atomic.AddUint64(&n, *x) + } + b.RunParallel(func(pb *testing.PB) { + x := uint64(1) + for pb.Next() { + f(&x) + } + }) + if n != uint64(b.N) { + b.Fatalf("unexected n: %d. Expecting %d", n, b.N) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/stackless/writer.go golang-github-valyala-fasthttp-1.31.0/stackless/writer.go --- golang-github-valyala-fasthttp-20160617/stackless/writer.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/writer.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,138 @@ +package stackless + +import ( + "errors" + "fmt" + "io" + + "github.com/valyala/bytebufferpool" +) + +// Writer is an interface stackless writer must conform to. +// +// The interface contains common subset for Writers from compress/* packages. +type Writer interface { + Write(p []byte) (int, error) + Flush() error + Close() error + Reset(w io.Writer) +} + +// NewWriterFunc must return new writer that will be wrapped into +// stackless writer. +type NewWriterFunc func(w io.Writer) Writer + +// NewWriter creates a stackless writer around a writer returned +// from newWriter. +// +// The returned writer writes data to dstW. +// +// Writers that use a lot of stack space may be wrapped into stackless writer, +// thus saving stack space for high number of concurrently running goroutines. 
+func NewWriter(dstW io.Writer, newWriter NewWriterFunc) Writer { + w := &writer{ + dstW: dstW, + } + w.zw = newWriter(&w.xw) + return w +} + +type writer struct { + dstW io.Writer + zw Writer + xw xWriter + + err error + n int + + p []byte + op op +} + +type op int + +const ( + opWrite op = iota + opFlush + opClose + opReset +) + +func (w *writer) Write(p []byte) (int, error) { + w.p = p + err := w.do(opWrite) + w.p = nil + return w.n, err +} + +func (w *writer) Flush() error { + return w.do(opFlush) +} + +func (w *writer) Close() error { + return w.do(opClose) +} + +func (w *writer) Reset(dstW io.Writer) { + w.xw.Reset() + w.do(opReset) //nolint:errcheck + w.dstW = dstW +} + +func (w *writer) do(op op) error { + w.op = op + if !stacklessWriterFunc(w) { + return errHighLoad + } + err := w.err + if err != nil { + return err + } + if w.xw.bb != nil && len(w.xw.bb.B) > 0 { + _, err = w.dstW.Write(w.xw.bb.B) + } + w.xw.Reset() + + return err +} + +var errHighLoad = errors.New("cannot compress data due to high load") + +var stacklessWriterFunc = NewFunc(writerFunc) + +func writerFunc(ctx interface{}) { + w := ctx.(*writer) + switch w.op { + case opWrite: + w.n, w.err = w.zw.Write(w.p) + case opFlush: + w.err = w.zw.Flush() + case opClose: + w.err = w.zw.Close() + case opReset: + w.zw.Reset(&w.xw) + w.err = nil + default: + panic(fmt.Sprintf("BUG: unexpected op: %d", w.op)) + } +} + +type xWriter struct { + bb *bytebufferpool.ByteBuffer +} + +func (w *xWriter) Write(p []byte) (int, error) { + if w.bb == nil { + w.bb = bufferPool.Get() + } + return w.bb.Write(p) +} + +func (w *xWriter) Reset() { + if w.bb != nil { + bufferPool.Put(w.bb) + w.bb = nil + } +} + +var bufferPool bytebufferpool.Pool diff -Nru golang-github-valyala-fasthttp-20160617/stackless/writer_test.go golang-github-valyala-fasthttp-1.31.0/stackless/writer_test.go --- golang-github-valyala-fasthttp-20160617/stackless/writer_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stackless/writer_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,130 @@ +package stackless + +import ( + "bytes" + "compress/flate" + "compress/gzip" + "fmt" + "io" + "io/ioutil" + "testing" + "time" +) + +func TestCompressFlateSerial(t *testing.T) { + t.Parallel() + + if err := testCompressFlate(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestCompressFlateConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(testCompressFlate, 10); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func testCompressFlate() error { + return testWriter(func(w io.Writer) Writer { + zw, err := flate.NewWriter(w, flate.DefaultCompression) + if err != nil { + panic(fmt.Sprintf("BUG: unexpected error: %s", err)) + } + return zw + }, func(r io.Reader) io.Reader { + return flate.NewReader(r) + }) +} + +func TestCompressGzipSerial(t *testing.T) { + t.Parallel() + + if err := testCompressGzip(); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestCompressGzipConcurrent(t *testing.T) { + t.Parallel() + + if err := testConcurrent(testCompressGzip, 10); err != nil { + t.Fatalf("unexpected error: %s", err) + } +} + +func testCompressGzip() error { + return testWriter(func(w io.Writer) Writer { + return gzip.NewWriter(w) + }, func(r io.Reader) io.Reader { + zr, err := gzip.NewReader(r) + if err != nil { + panic(fmt.Sprintf("BUG: cannot create gzip reader: %s", err)) + } + return zr + }) +} + +func testWriter(newWriter NewWriterFunc, newReader func(io.Reader) io.Reader) error { 
+ dstW := &bytes.Buffer{} + w := NewWriter(dstW, newWriter) + + for i := 0; i < 5; i++ { + if err := testWriterReuse(w, dstW, newReader); err != nil { + return fmt.Errorf("unexpected error when re-using writer on iteration %d: %s", i, err) + } + dstW = &bytes.Buffer{} + w.Reset(dstW) + } + + return nil +} + +func testWriterReuse(w Writer, r io.Reader, newReader func(io.Reader) io.Reader) error { + wantW := &bytes.Buffer{} + mw := io.MultiWriter(w, wantW) + for i := 0; i < 30; i++ { + fmt.Fprintf(mw, "foobar %d\n", i) + if i%13 == 0 { + if err := w.Flush(); err != nil { + return fmt.Errorf("error on flush: %s", err) + } + } + } + w.Close() + + zr := newReader(r) + data, err := ioutil.ReadAll(zr) + if err != nil { + return fmt.Errorf("unexpected error: %s, data=%q", err, data) + } + + wantData := wantW.Bytes() + if !bytes.Equal(data, wantData) { + return fmt.Errorf("unexpected data: %q. Expecting %q", data, wantData) + } + + return nil +} + +func testConcurrent(testFunc func() error, concurrency int) error { + ch := make(chan error, concurrency) + for i := 0; i < concurrency; i++ { + go func() { + ch <- testFunc() + }() + } + for i := 0; i < concurrency; i++ { + select { + case err := <-ch: + if err != nil { + return fmt.Errorf("unexpected error on goroutine %d: %s", i, err) + } + case <-time.After(time.Second): + return fmt.Errorf("timeout on goroutine %d", i) + } + } + return nil +} diff -Nru golang-github-valyala-fasthttp-20160617/status.go golang-github-valyala-fasthttp-1.31.0/status.go --- golang-github-valyala-fasthttp-20160617/status.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/status.go 2021-10-09 18:39:05.000000000 +0000 @@ -2,76 +2,103 @@ import ( "fmt" - "sync/atomic" + "strconv" ) -// HTTP status codes were stolen from net/http. const ( - StatusContinue = 100 - StatusSwitchingProtocols = 101 + statusMessageMin = 100 + statusMessageMax = 511 +) - StatusOK = 200 - StatusCreated = 201 - StatusAccepted = 202 - StatusNonAuthoritativeInfo = 203 - StatusNoContent = 204 - StatusResetContent = 205 - StatusPartialContent = 206 - - StatusMultipleChoices = 300 - StatusMovedPermanently = 301 - StatusFound = 302 - StatusSeeOther = 303 - StatusNotModified = 304 - StatusUseProxy = 305 - StatusTemporaryRedirect = 307 - - StatusBadRequest = 400 - StatusUnauthorized = 401 - StatusPaymentRequired = 402 - StatusForbidden = 403 - StatusNotFound = 404 - StatusMethodNotAllowed = 405 - StatusNotAcceptable = 406 - StatusProxyAuthRequired = 407 - StatusRequestTimeout = 408 - StatusConflict = 409 - StatusGone = 410 - StatusLengthRequired = 411 - StatusPreconditionFailed = 412 - StatusRequestEntityTooLarge = 413 - StatusRequestURITooLong = 414 - StatusUnsupportedMediaType = 415 - StatusRequestedRangeNotSatisfiable = 416 - StatusExpectationFailed = 417 - StatusTeapot = 418 - StatusPreconditionRequired = 428 - StatusTooManyRequests = 429 - StatusRequestHeaderFieldsTooLarge = 431 - - StatusInternalServerError = 500 - StatusNotImplemented = 501 - StatusBadGateway = 502 - StatusServiceUnavailable = 503 - StatusGatewayTimeout = 504 - StatusHTTPVersionNotSupported = 505 - StatusNetworkAuthenticationRequired = 511 +// HTTP status codes were stolen from net/http. 
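Before the expanded status-code table that follows, a short sketch of how these constants and StatusMessage are consumed from application code. Hedged: StatusMessage, StatusLocked and RequestCtx.SetStatusCode/SetBodyString are the public fasthttp API shown or implied by this diff; the handler itself is purely illustrative.

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	// StatusMessage maps a code to its reason phrase; codes outside the
	// 100..511 table fall back to "Unknown Status Code".
	fmt.Println(fasthttp.StatusMessage(fasthttp.StatusTeapot)) // "I'm a teapot"
	fmt.Println(fasthttp.StatusMessage(99))                    // "Unknown Status Code"

	// Handlers normally just set the numeric code; fasthttp writes the
	// matching precomputed "HTTP/1.1 <code> <message>\r\n" status line.
	handler := func(ctx *fasthttp.RequestCtx) {
		ctx.SetStatusCode(fasthttp.StatusLocked) // 423, added in this version
		ctx.SetBodyString("resource is locked")
	}
	_ = handler // wire this into fasthttp.Server{Handler: handler} as needed
}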
+const ( + StatusContinue = 100 // RFC 7231, 6.2.1 + StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 + StatusProcessing = 102 // RFC 2518, 10.1 + StatusEarlyHints = 103 // RFC 8297 + + StatusOK = 200 // RFC 7231, 6.3.1 + StatusCreated = 201 // RFC 7231, 6.3.2 + StatusAccepted = 202 // RFC 7231, 6.3.3 + StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4 + StatusNoContent = 204 // RFC 7231, 6.3.5 + StatusResetContent = 205 // RFC 7231, 6.3.6 + StatusPartialContent = 206 // RFC 7233, 4.1 + StatusMultiStatus = 207 // RFC 4918, 11.1 + StatusAlreadyReported = 208 // RFC 5842, 7.1 + StatusIMUsed = 226 // RFC 3229, 10.4.1 + + StatusMultipleChoices = 300 // RFC 7231, 6.4.1 + StatusMovedPermanently = 301 // RFC 7231, 6.4.2 + StatusFound = 302 // RFC 7231, 6.4.3 + StatusSeeOther = 303 // RFC 7231, 6.4.4 + StatusNotModified = 304 // RFC 7232, 4.1 + StatusUseProxy = 305 // RFC 7231, 6.4.5 + _ = 306 // RFC 7231, 6.4.6 (Unused) + StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7 + StatusPermanentRedirect = 308 // RFC 7538, 3 + + StatusBadRequest = 400 // RFC 7231, 6.5.1 + StatusUnauthorized = 401 // RFC 7235, 3.1 + StatusPaymentRequired = 402 // RFC 7231, 6.5.2 + StatusForbidden = 403 // RFC 7231, 6.5.3 + StatusNotFound = 404 // RFC 7231, 6.5.4 + StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5 + StatusNotAcceptable = 406 // RFC 7231, 6.5.6 + StatusProxyAuthRequired = 407 // RFC 7235, 3.2 + StatusRequestTimeout = 408 // RFC 7231, 6.5.7 + StatusConflict = 409 // RFC 7231, 6.5.8 + StatusGone = 410 // RFC 7231, 6.5.9 + StatusLengthRequired = 411 // RFC 7231, 6.5.10 + StatusPreconditionFailed = 412 // RFC 7232, 4.2 + StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11 + StatusRequestURITooLong = 414 // RFC 7231, 6.5.12 + StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13 + StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4 + StatusExpectationFailed = 417 // RFC 7231, 6.5.14 + StatusTeapot = 418 // RFC 7168, 2.3.3 + StatusMisdirectedRequest = 421 // RFC 7540, 9.1.2 + StatusUnprocessableEntity = 422 // RFC 4918, 11.2 + StatusLocked = 423 // RFC 4918, 11.3 + StatusFailedDependency = 424 // RFC 4918, 11.4 + StatusUpgradeRequired = 426 // RFC 7231, 6.5.15 + StatusPreconditionRequired = 428 // RFC 6585, 3 + StatusTooManyRequests = 429 // RFC 6585, 4 + StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 + StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 + + StatusInternalServerError = 500 // RFC 7231, 6.6.1 + StatusNotImplemented = 501 // RFC 7231, 6.6.2 + StatusBadGateway = 502 // RFC 7231, 6.6.3 + StatusServiceUnavailable = 503 // RFC 7231, 6.6.4 + StatusGatewayTimeout = 504 // RFC 7231, 6.6.5 + StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6 + StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1 + StatusInsufficientStorage = 507 // RFC 4918, 11.5 + StatusLoopDetected = 508 // RFC 5842, 7.2 + StatusNotExtended = 510 // RFC 2774, 7 + StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 ) var ( - statusLines atomic.Value + statusLines = make([][]byte, statusMessageMax+1) - statusMessages = map[int]string{ + statusMessages = []string{ StatusContinue: "Continue", - StatusSwitchingProtocols: "SwitchingProtocols", + StatusSwitchingProtocols: "Switching Protocols", + StatusProcessing: "Processing", + StatusEarlyHints: "Early Hints", StatusOK: "OK", StatusCreated: "Created", StatusAccepted: "Accepted", - StatusNonAuthoritativeInfo: "Non-Authoritative Info", + StatusNonAuthoritativeInfo: "Non-Authoritative Information", StatusNoContent: "No Content", StatusResetContent: "Reset Content", 
StatusPartialContent: "Partial Content", + StatusMultiStatus: "Multi-Status", + StatusAlreadyReported: "Already Reported", + StatusIMUsed: "IM Used", StatusMultipleChoices: "Multiple Choices", StatusMovedPermanently: "Moved Permanently", @@ -80,6 +107,7 @@ StatusNotModified: "Not Modified", StatusUseProxy: "Use Proxy", StatusTemporaryRedirect: "Temporary Redirect", + StatusPermanentRedirect: "Permanent Redirect", StatusBadRequest: "Bad Request", StatusUnauthorized: "Unauthorized", @@ -88,7 +116,7 @@ StatusNotFound: "Not Found", StatusMethodNotAllowed: "Method Not Allowed", StatusNotAcceptable: "Not Acceptable", - StatusProxyAuthRequired: "Proxy Auth Required", + StatusProxyAuthRequired: "Proxy Authentication Required", StatusRequestTimeout: "Request Timeout", StatusConflict: "Conflict", StatusGone: "Gone", @@ -99,10 +127,16 @@ StatusUnsupportedMediaType: "Unsupported Media Type", StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", StatusExpectationFailed: "Expectation Failed", - StatusTeapot: "Teapot", + StatusTeapot: "I'm a teapot", + StatusMisdirectedRequest: "Misdirected Request", + StatusUnprocessableEntity: "Unprocessable Entity", + StatusLocked: "Locked", + StatusFailedDependency: "Failed Dependency", + StatusUpgradeRequired: "Upgrade Required", StatusPreconditionRequired: "Precondition Required", StatusTooManyRequests: "Too Many Requests", - StatusRequestHeaderFieldsTooLarge: "Request HeaderFields Too Large", + StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", + StatusUnavailableForLegalReasons: "Unavailable For Legal Reasons", StatusInternalServerError: "Internal Server Error", StatusNotImplemented: "Not Implemented", @@ -110,12 +144,20 @@ StatusServiceUnavailable: "Service Unavailable", StatusGatewayTimeout: "Gateway Timeout", StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + StatusVariantAlsoNegotiates: "Variant Also Negotiates", + StatusInsufficientStorage: "Insufficient Storage", + StatusLoopDetected: "Loop Detected", + StatusNotExtended: "Not Extended", StatusNetworkAuthenticationRequired: "Network Authentication Required", } ) // StatusMessage returns HTTP status message for the given status code. func StatusMessage(statusCode int) string { + if statusCode < statusMessageMin || statusCode > statusMessageMax { + return "Unknown Status Code" + } + s := statusMessages[statusCode] if s == "" { s = "Unknown Status Code" @@ -124,24 +166,28 @@ } func init() { - statusLines.Store(make(map[int][]byte)) + // Fill all valid status lines + for i := 0; i < len(statusLines); i++ { + statusLines[i] = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", i, StatusMessage(i))) + } } func statusLine(statusCode int) []byte { - m := statusLines.Load().(map[int][]byte) - h := m[statusCode] - if h != nil { - return h + if statusCode < 0 || statusCode > statusMessageMax { + return invalidStatusLine(statusCode) } - statusText := StatusMessage(statusCode) + return statusLines[statusCode] +} - h = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n", statusCode, statusText)) - newM := make(map[int][]byte, len(m)+1) - for k, v := range m { - newM[k] = v - } - newM[statusCode] = h - statusLines.Store(newM) - return h +func invalidStatusLine(statusCode int) []byte { + statusText := StatusMessage(statusCode) + // xxx placeholder of status code + var line = make([]byte, 0, len("HTTP/1.1 xxx \r\n")+len(statusText)) + line = append(line, "HTTP/1.1 "...) + line = strconv.AppendInt(line, int64(statusCode), 10) + line = append(line, ' ') + line = append(line, statusText...) 
+ line = append(line, "\r\n"...) + return line } diff -Nru golang-github-valyala-fasthttp-20160617/status_test.go golang-github-valyala-fasthttp-1.31.0/status_test.go --- golang-github-valyala-fasthttp-20160617/status_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/status_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,24 @@ +package fasthttp + +import ( + "bytes" + "testing" +) + +func TestStatusLine(t *testing.T) { + t.Parallel() + + testStatusLine(t, -1, []byte("HTTP/1.1 -1 Unknown Status Code\r\n")) + testStatusLine(t, 99, []byte("HTTP/1.1 99 Unknown Status Code\r\n")) + testStatusLine(t, 200, []byte("HTTP/1.1 200 OK\r\n")) + testStatusLine(t, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n")) + testStatusLine(t, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n")) + testStatusLine(t, 520, []byte("HTTP/1.1 520 Unknown Status Code\r\n")) +} + +func testStatusLine(t *testing.T, statusCode int, expected []byte) { + line := statusLine(statusCode) + if !bytes.Equal(expected, line) { + t.Fatalf("unexpected status line %s. Expecting %s", string(line), string(expected)) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/status_timing_test.go golang-github-valyala-fasthttp-1.31.0/status_timing_test.go --- golang-github-valyala-fasthttp-20160617/status_timing_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/status_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,29 @@ +package fasthttp + +import ( + "bytes" + "testing" +) + +func BenchmarkStatusLine99(b *testing.B) { + benchmarkStatusLine(b, 99, []byte("HTTP/1.1 99 Unknown Status Code\r\n")) +} + +func BenchmarkStatusLine200(b *testing.B) { + benchmarkStatusLine(b, 200, []byte("HTTP/1.1 200 OK\r\n")) +} + +func BenchmarkStatusLine512(b *testing.B) { + benchmarkStatusLine(b, 512, []byte("HTTP/1.1 512 Unknown Status Code\r\n")) +} + +func benchmarkStatusLine(b *testing.B, statusCode int, expected []byte) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + line := statusLine(statusCode) + if !bytes.Equal(expected, line) { + b.Fatalf("unexpected status line %s. 
Expecting %s", string(line), string(expected)) + } + } + }) +} diff -Nru golang-github-valyala-fasthttp-20160617/stream.go golang-github-valyala-fasthttp-1.31.0/stream.go --- golang-github-valyala-fasthttp-20160617/stream.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stream.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,7 +3,6 @@ import ( "bufio" "io" - "runtime/debug" "sync" "github.com/valyala/fasthttp/fasthttputil" @@ -42,12 +41,6 @@ } go func() { - defer func() { - if r := recover(); r != nil { - defaultLogger.Printf("panic in StreamWriter: %s\nStack trace:\n%s", r, debug.Stack()) - } - }() - sw(bw) bw.Flush() pw.Close() diff -Nru golang-github-valyala-fasthttp-20160617/streaming.go golang-github-valyala-fasthttp-1.31.0/streaming.go --- golang-github-valyala-fasthttp-20160617/streaming.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/streaming.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,109 @@ +package fasthttp + +import ( + "bufio" + "bytes" + "io" + "sync" + + "github.com/valyala/bytebufferpool" +) + +type requestStream struct { + prefetchedBytes *bytes.Reader + reader *bufio.Reader + totalBytesRead int + contentLength int + chunkLeft int +} + +func (rs *requestStream) Read(p []byte) (int, error) { + var ( + n int + err error + ) + if rs.contentLength == -1 { + if rs.chunkLeft == 0 { + chunkSize, err := parseChunkSize(rs.reader) + if err != nil { + return 0, err + } + if chunkSize == 0 { + err = readCrLf(rs.reader) + if err == nil { + err = io.EOF + } + return 0, err + } + rs.chunkLeft = chunkSize + } + bytesToRead := len(p) + if rs.chunkLeft < len(p) { + bytesToRead = rs.chunkLeft + } + n, err = rs.reader.Read(p[:bytesToRead]) + rs.totalBytesRead += n + rs.chunkLeft -= n + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if err == nil && rs.chunkLeft == 0 { + err = readCrLf(rs.reader) + } + return n, err + } + if rs.totalBytesRead == rs.contentLength { + return 0, io.EOF + } + prefetchedSize := int(rs.prefetchedBytes.Size()) + if prefetchedSize > rs.totalBytesRead { + left := prefetchedSize - rs.totalBytesRead + if len(p) > left { + p = p[:left] + } + n, err := rs.prefetchedBytes.Read(p) + rs.totalBytesRead += n + if n == rs.contentLength { + return n, io.EOF + } + return n, err + } else { + left := rs.contentLength - rs.totalBytesRead + if len(p) > left { + p = p[:left] + } + n, err = rs.reader.Read(p) + rs.totalBytesRead += n + if err != nil { + return n, err + } + } + + if rs.totalBytesRead == rs.contentLength { + err = io.EOF + } + return n, err +} + +func acquireRequestStream(b *bytebufferpool.ByteBuffer, r *bufio.Reader, contentLength int) *requestStream { + rs := requestStreamPool.Get().(*requestStream) + rs.prefetchedBytes = bytes.NewReader(b.B) + rs.reader = r + rs.contentLength = contentLength + + return rs +} + +func releaseRequestStream(rs *requestStream) { + rs.prefetchedBytes = nil + rs.totalBytesRead = 0 + rs.chunkLeft = 0 + rs.reader = nil + requestStreamPool.Put(rs) +} + +var requestStreamPool = sync.Pool{ + New: func() interface{} { + return &requestStream{} + }, +} diff -Nru golang-github-valyala-fasthttp-20160617/streaming_test.go golang-github-valyala-fasthttp-1.31.0/streaming_test.go --- golang-github-valyala-fasthttp-20160617/streaming_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/streaming_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,193 @@ +package fasthttp + +import ( + "bufio" + "bytes" + "io/ioutil" + "sync" + "testing" + "time" + + 
"github.com/valyala/fasthttp/fasthttputil" +) + +func TestStreamingPipeline(t *testing.T) { + t.Parallel() + + reqS := `POST /one HTTP/1.1 +Host: example.com +Content-Length: 10 + +aaaaaaaaaa +POST /two HTTP/1.1 +Host: example.com +Content-Length: 10 + +aaaaaaaaaa` + + ln := fasthttputil.NewInmemoryListener() + + s := &Server{ + StreamRequestBody: true, + Handler: func(ctx *RequestCtx) { + body := "" + expected := "aaaaaaaaaa" + if string(ctx.Path()) == "/one" { + body = string(ctx.PostBody()) + } else { + all, err := ioutil.ReadAll(ctx.RequestBodyStream()) + if err != nil { + t.Error(err) + } + body = string(all) + } + if body != expected { + t.Errorf("expected %q got %q", expected, body) + } + }, + } + + ch := make(chan struct{}) + go func() { + if err := s.Serve(ln); err != nil { + t.Errorf("unexpected error: %s", err) + } + close(ch) + }() + + conn, err := ln.Dial() + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + if _, err = conn.Write([]byte(reqS)); err != nil { + t.Fatalf("unexpected error: %s", err) + } + + var resp Response + br := bufio.NewReader(conn) + respCh := make(chan struct{}) + go func() { + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + + if err := resp.Read(br); err != nil { + t.Errorf("error when reading response: %s", err) + } + if resp.StatusCode() != StatusOK { + t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK) + } + close(respCh) + }() + + select { + case <-respCh: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if err := ln.Close(); err != nil { + t.Fatalf("error when closing listener: %s", err) + } + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout when waiting for the server to stop") + } +} + +func getChunkedTestEnv(t testing.TB) (*fasthttputil.InmemoryListener, []byte) { + body := createFixedBody(128 * 1024) + chunkedBody := createChunkedBody(body) + + testHandler := func(ctx *RequestCtx) { + bodyBytes, err := ioutil.ReadAll(ctx.RequestBodyStream()) + if err != nil { + t.Logf("ioutil read returned err=%s", err) + t.Error("unexpected error while reading request body stream") + } + + if !bytes.Equal(body, bodyBytes) { + t.Errorf("unexpected request body, expected %q, got %q", body, bodyBytes) + } + } + s := &Server{ + Handler: testHandler, + StreamRequestBody: true, + MaxRequestBodySize: 1, // easier to test with small limit + } + + ln := fasthttputil.NewInmemoryListener() + + go func() { + err := s.Serve(ln) + if err != nil { + t.Errorf("could not serve listener: %s", err) + } + }() + + req := Request{} + req.SetHost("localhost") + req.Header.SetMethod("POST") + req.Header.Set("transfer-encoding", "chunked") + req.Header.SetContentLength(-1) + + formattedRequest := req.Header.Header() + formattedRequest = append(formattedRequest, chunkedBody...) 
+ + return ln, formattedRequest +} + +func TestRequestStream(t *testing.T) { + t.Parallel() + + ln, formattedRequest := getChunkedTestEnv(t) + + c, err := ln.Dial() + if err != nil { + t.Errorf("unexpected error while dialing: %s", err) + } + if _, err = c.Write(formattedRequest); err != nil { + t.Errorf("unexpected error while writing request: %s", err) + } + + br := bufio.NewReader(c) + var respH ResponseHeader + if err = respH.Read(br); err != nil { + t.Errorf("unexpected error: %s", err) + } +} + +func BenchmarkRequestStreamE2E(b *testing.B) { + ln, formattedRequest := getChunkedTestEnv(b) + + wg := &sync.WaitGroup{} + wg.Add(4) + for i := 0; i < 4; i++ { + go func(wg *sync.WaitGroup) { + for i := 0; i < b.N/4; i++ { + c, err := ln.Dial() + if err != nil { + b.Errorf("unexpected error while dialing: %s", err) + } + if _, err = c.Write(formattedRequest); err != nil { + b.Errorf("unexpected error while writing request: %s", err) + } + + br := bufio.NewReaderSize(c, 128) + var respH ResponseHeader + if err = respH.Read(br); err != nil { + b.Errorf("unexpected error: %s", err) + } + c.Close() + } + wg.Done() + }(wg) + } + + wg.Wait() +} diff -Nru golang-github-valyala-fasthttp-20160617/stream_test.go golang-github-valyala-fasthttp-1.31.0/stream_test.go --- golang-github-valyala-fasthttp-20160617/stream_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/stream_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -10,6 +10,8 @@ ) func TestNewStreamReader(t *testing.T) { + t.Parallel() + ch := make(chan struct{}) r := NewStreamReader(func(w *bufio.Writer) { fmt.Fprintf(w, "Hello, world\n") @@ -38,6 +40,8 @@ } func TestStreamReaderClose(t *testing.T) { + t.Parallel() + firstLine := "the first line must pass" ch := make(chan error, 1) r := NewStreamReader(func(w *bufio.Writer) { @@ -49,7 +53,7 @@ data := createFixedBody(4000) for i := 0; i < 100; i++ { - w.Write(data) + w.Write(data) //nolint:errcheck } if err := w.Flush(); err == nil { ch <- fmt.Errorf("expecting error on the second flush") diff -Nru golang-github-valyala-fasthttp-20160617/strings.go golang-github-valyala-fasthttp-1.31.0/strings.go --- golang-github-valyala-fasthttp-20160617/strings.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/strings.go 2021-10-09 18:39:05.000000000 +0000 @@ -15,57 +15,68 @@ strCRLF = []byte("\r\n") strHTTP = []byte("http") strHTTPS = []byte("https") + strHTTP10 = []byte("HTTP/1.0") strHTTP11 = []byte("HTTP/1.1") + strColon = []byte(":") strColonSlashSlash = []byte("://") strColonSpace = []byte(": ") strGMT = []byte("GMT") strResponseContinue = []byte("HTTP/1.1 100 Continue\r\n\r\n") - strGet = []byte("GET") - strHead = []byte("HEAD") - strPost = []byte("POST") - strPut = []byte("PUT") - strDelete = []byte("DELETE") + strExpect = []byte(HeaderExpect) + strConnection = []byte(HeaderConnection) + strContentLength = []byte(HeaderContentLength) + strContentType = []byte(HeaderContentType) + strDate = []byte(HeaderDate) + strHost = []byte(HeaderHost) + strReferer = []byte(HeaderReferer) + strServer = []byte(HeaderServer) + strTransferEncoding = []byte(HeaderTransferEncoding) + strContentEncoding = []byte(HeaderContentEncoding) + strAcceptEncoding = []byte(HeaderAcceptEncoding) + strUserAgent = []byte(HeaderUserAgent) + strCookie = []byte(HeaderCookie) + strSetCookie = []byte(HeaderSetCookie) + strLocation = []byte(HeaderLocation) + strIfModifiedSince = []byte(HeaderIfModifiedSince) + strLastModified = []byte(HeaderLastModified) + strAcceptRanges = 
[]byte(HeaderAcceptRanges) + strRange = []byte(HeaderRange) + strContentRange = []byte(HeaderContentRange) + strAuthorization = []byte(HeaderAuthorization) - strExpect = []byte("Expect") - strConnection = []byte("Connection") - strContentLength = []byte("Content-Length") - strContentType = []byte("Content-Type") - strDate = []byte("Date") - strHost = []byte("Host") - strReferer = []byte("Referer") - strServer = []byte("Server") - strTransferEncoding = []byte("Transfer-Encoding") - strContentEncoding = []byte("Content-Encoding") - strAcceptEncoding = []byte("Accept-Encoding") - strUserAgent = []byte("User-Agent") - strCookie = []byte("Cookie") - strSetCookie = []byte("Set-Cookie") - strLocation = []byte("Location") - strIfModifiedSince = []byte("If-Modified-Since") - strLastModified = []byte("Last-Modified") - strAcceptRanges = []byte("Accept-Ranges") - strRange = []byte("Range") - strContentRange = []byte("Content-Range") - - strCookieExpires = []byte("expires") - strCookieDomain = []byte("domain") - strCookiePath = []byte("path") - strCookieHTTPOnly = []byte("HttpOnly") - strCookieSecure = []byte("secure") + strCookieExpires = []byte("expires") + strCookieDomain = []byte("domain") + strCookiePath = []byte("path") + strCookieHTTPOnly = []byte("HttpOnly") + strCookieSecure = []byte("secure") + strCookieMaxAge = []byte("max-age") + strCookieSameSite = []byte("SameSite") + strCookieSameSiteLax = []byte("Lax") + strCookieSameSiteStrict = []byte("Strict") + strCookieSameSiteNone = []byte("None") strClose = []byte("close") strGzip = []byte("gzip") + strBr = []byte("br") strDeflate = []byte("deflate") strKeepAlive = []byte("keep-alive") - strKeepAliveCamelCase = []byte("Keep-Alive") strUpgrade = []byte("Upgrade") strChunked = []byte("chunked") strIdentity = []byte("identity") str100Continue = []byte("100-continue") strPostArgsContentType = []byte("application/x-www-form-urlencoded") + strDefaultContentType = []byte("application/octet-stream") strMultipartFormData = []byte("multipart/form-data") strBoundary = []byte("boundary") strBytes = []byte("bytes") + strBasicSpace = []byte("Basic ") + + strApplicationSlash = []byte("application/") + strImageSVG = []byte("image/svg") + strImageIcon = []byte("image/x-icon") + strFontSlash = []byte("font/") + strMultipartSlash = []byte("multipart/") + strTextSlash = []byte("text/") ) diff -Nru golang-github-valyala-fasthttp-20160617/tcpdialer.go golang-github-valyala-fasthttp-1.31.0/tcpdialer.go --- golang-github-valyala-fasthttp-20160617/tcpdialer.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/tcpdialer.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,6 +1,7 @@ package fasthttp import ( + "context" "errors" "net" "strconv" @@ -14,7 +15,7 @@ // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed -// for DefaultDNSCacheDuration. +// for DNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. @@ -33,7 +34,7 @@ // * foo.bar:80 // * aaa.com:8080 func Dial(addr string) (net.Conn, error) { - return getDialer(DefaultDialTimeout, false)(addr) + return defaultDialer.Dial(addr) } // DialTimeout dials the given TCP addr using tcp4 using the given timeout. 
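The Dial helpers documented here are what fasthttp clients use by default; a hedged sketch of plugging an explicit per-connection timeout into a Client through its Dial hook follows. Client.Dial, DialFunc, Client.Get and DialTimeout are public fasthttp API; the target URL is a placeholder.

package main

import (
	"fmt"
	"net"
	"time"

	"github.com/valyala/fasthttp"
)

func main() {
	c := &fasthttp.Client{
		// Keep fasthttp's cached-DNS, round-robin dialer, but cap every
		// connection attempt at 3 seconds instead of DefaultDialTimeout.
		Dial: func(addr string) (net.Conn, error) {
			return fasthttp.DialTimeout(addr, 3*time.Second)
		},
	}

	statusCode, body, err := c.Get(nil, "http://example.com/")
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	fmt.Println(statusCode, len(body))
}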
@@ -41,7 +42,7 @@ // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed -// for DefaultDNSCacheDuration. +// for DNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. @@ -58,7 +59,7 @@ // * foo.bar:80 // * aaa.com:8080 func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { - return getDialer(timeout, false)(addr) + return defaultDialer.DialTimeout(addr, timeout) } // DialDualStack dials the given TCP addr using both tcp4 and tcp6. @@ -66,7 +67,7 @@ // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed -// for DefaultDNSCacheDuration. +// for DNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. @@ -86,7 +87,7 @@ // * foo.bar:80 // * aaa.com:8080 func DialDualStack(addr string) (net.Conn, error) { - return getDialer(DefaultDialTimeout, true)(addr) + return defaultDialer.DialDualStack(addr) } // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 @@ -95,7 +96,7 @@ // This function has the following additional features comparing to net.Dial: // // * It reduces load on DNS resolver by caching resolved TCP addressed -// for DefaultDNSCacheDuration. +// for DNSCacheDuration. // * It dials all the resolved TCP addresses in round-robin manner until // connection is established. This may be useful if certain addresses // are temporarily unreachable. @@ -112,157 +113,240 @@ // * foo.bar:80 // * aaa.com:8080 func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { - return getDialer(timeout, true)(addr) -} - -func getDialer(timeout time.Duration, dualStack bool) DialFunc { - if timeout <= 0 { - timeout = DefaultDialTimeout - } - timeoutRounded := int(timeout.Seconds()*10 + 9) - - m := dialMap - if dualStack { - m = dialDualStackMap - } - - dialMapLock.Lock() - d := m[timeoutRounded] - if d == nil { - dialer := dialerStd - if dualStack { - dialer = dialerDualStack - } - d = dialer.NewDial(timeout) - m[timeoutRounded] = d - } - dialMapLock.Unlock() - return d + return defaultDialer.DialDualStackTimeout(addr, timeout) } var ( - dialerStd = &tcpDialer{} - dialerDualStack = &tcpDialer{DualStack: true} - - dialMap = make(map[int]DialFunc) - dialDualStackMap = make(map[int]DialFunc) - dialMapLock sync.Mutex + defaultDialer = &TCPDialer{Concurrency: 1000} ) -type tcpDialer struct { - DualStack bool +// Resolver represents interface of the tcp resolver. +type Resolver interface { + LookupIPAddr(context.Context, string) (names []net.IPAddr, err error) +} + +// TCPDialer contains options to control a group of Dial calls. +type TCPDialer struct { + // Concurrency controls the maximum number of concurrent Dails + // that can be performed using this object. + // Setting this to 0 means unlimited. + // + // WARNING: This can only be changed before the first Dial. + // Changes made after the first Dial will not affect anything. + Concurrency int + + // LocalAddr is the local address to use when dialing an + // address. + // If nil, a local address is automatically chosen. 
+ LocalAddr *net.TCPAddr + + // This may be used to override DNS resolving policy, like this: + // var dialer = &fasthttp.TCPDialer{ + // Resolver: &net.Resolver{ + // PreferGo: true, + // StrictErrors: false, + // Dial: func (ctx context.Context, network, address string) (net.Conn, error) { + // d := net.Dialer{} + // return d.DialContext(ctx, "udp", "8.8.8.8:53") + // }, + // }, + // } + Resolver Resolver + + // DNSCacheDuration may be used to override the default DNS cache duration (DefaultDNSCacheDuration) + DNSCacheDuration time.Duration - tcpAddrsLock sync.Mutex - tcpAddrsMap map[string]*tcpAddrEntry + tcpAddrsMap sync.Map concurrencyCh chan struct{} once sync.Once } -const maxDialConcurrency = 1000 +// Dial dials the given TCP addr using tcp4. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// * It returns ErrDialTimeout if connection cannot be established during +// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) Dial(addr string) (net.Conn, error) { + return d.dial(addr, false, DefaultDialTimeout) +} + +// DialTimeout dials the given TCP addr using tcp4 using the given timeout. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + return d.dial(addr, false, timeout) +} -func (d *tcpDialer) NewDial(timeout time.Duration) DialFunc { +// DialDualStack dials the given TCP addr using both tcp4 and tcp6. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// * It returns ErrDialTimeout if connection cannot be established during +// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial +// timeout. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. 
+// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) { + return d.dial(addr, true, DefaultDialTimeout) +} + +// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 +// using the given timeout. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { + return d.dial(addr, true, timeout) +} + +func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) { d.once.Do(func() { - d.concurrencyCh = make(chan struct{}, maxDialConcurrency) - d.tcpAddrsMap = make(map[string]*tcpAddrEntry) + if d.Concurrency > 0 { + d.concurrencyCh = make(chan struct{}, d.Concurrency) + } + + if d.DNSCacheDuration == 0 { + d.DNSCacheDuration = DefaultDNSCacheDuration + } + go d.tcpAddrsClean() }) - return func(addr string) (net.Conn, error) { - addrs, idx, err := d.getTCPAddrs(addr) - if err != nil { - return nil, err - } - network := "tcp4" - if d.DualStack { - network = "tcp" - } + addrs, idx, err := d.getTCPAddrs(addr, dualStack) + if err != nil { + return nil, err + } + network := "tcp4" + if dualStack { + network = "tcp" + } - var conn net.Conn - n := uint32(len(addrs)) - deadline := time.Now().Add(timeout) - for n > 0 { - conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) - if err == nil { - return conn, nil - } - if err == ErrDialTimeout { - return nil, err - } - idx++ - n-- + var conn net.Conn + n := uint32(len(addrs)) + deadline := time.Now().Add(timeout) + for n > 0 { + conn, err = d.tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) + if err == nil { + return conn, nil } - return nil, err + if err == ErrDialTimeout { + return nil, err + } + idx++ + n-- } + return nil, err } -func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) { +func (d *TCPDialer) tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) { timeout := -time.Since(deadline) if timeout <= 0 { return nil, ErrDialTimeout } - select { - case concurrencyCh <- struct{}{}: - default: - tc := acquireTimer(timeout) - isTimeout := false + if concurrencyCh != nil { select { case concurrencyCh <- struct{}{}: - case <-tc.C: - isTimeout = true - } - releaseTimer(tc) - if isTimeout { - return nil, ErrDialTimeout + default: + tc := AcquireTimer(timeout) + isTimeout := false + select { + case concurrencyCh <- struct{}{}: + case <-tc.C: + isTimeout = true + } + ReleaseTimer(tc) + if isTimeout { + return nil, ErrDialTimeout + } } + defer func() { <-concurrencyCh }() } - timeout = -time.Since(deadline) - if timeout <= 0 { - <-concurrencyCh - return nil, ErrDialTimeout + dialer := 
net.Dialer{} + if d.LocalAddr != nil { + dialer.LocalAddr = d.LocalAddr } - chv := dialResultChanPool.Get() - if chv == nil { - chv = make(chan dialResult, 1) - } - ch := chv.(chan dialResult) - go func() { - var dr dialResult - dr.conn, dr.err = net.DialTCP(network, nil, addr) - ch <- dr - <-concurrencyCh - }() - - var ( - conn net.Conn - err error - ) - - tc := acquireTimer(timeout) - select { - case dr := <-ch: - conn = dr.conn - err = dr.err - dialResultChanPool.Put(ch) - case <-tc.C: - err = ErrDialTimeout + ctx, cancel_ctx := context.WithDeadline(context.Background(), deadline) + defer cancel_ctx() + conn, err := dialer.DialContext(ctx, network, addr.String()) + if err != nil && ctx.Err() == context.DeadlineExceeded { + return nil, ErrDialTimeout } - releaseTimer(tc) - return conn, err } -var dialResultChanPool sync.Pool - -type dialResult struct { - conn net.Conn - err error -} - // ErrDialTimeout is returned when TCP dialing is timed out. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out") @@ -282,40 +366,37 @@ // by Dial* functions. const DefaultDNSCacheDuration = time.Minute -func (d *tcpDialer) tcpAddrsClean() { - expireDuration := 2 * DefaultDNSCacheDuration +func (d *TCPDialer) tcpAddrsClean() { + expireDuration := 2 * d.DNSCacheDuration for { time.Sleep(time.Second) t := time.Now() - - d.tcpAddrsLock.Lock() - for k, e := range d.tcpAddrsMap { - if t.Sub(e.resolveTime) > expireDuration { - delete(d.tcpAddrsMap, k) + d.tcpAddrsMap.Range(func(k, v interface{}) bool { + if e, ok := v.(*tcpAddrEntry); ok && t.Sub(e.resolveTime) > expireDuration { + d.tcpAddrsMap.Delete(k) } - } - d.tcpAddrsLock.Unlock() + return true + }) + } } -func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { - d.tcpAddrsLock.Lock() - e := d.tcpAddrsMap[addr] - if e != nil && !e.pending && time.Since(e.resolveTime) > DefaultDNSCacheDuration { +func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) { + item, exist := d.tcpAddrsMap.Load(addr) + e, ok := item.(*tcpAddrEntry) + if exist && ok && e != nil && !e.pending && time.Since(e.resolveTime) > d.DNSCacheDuration { e.pending = true e = nil } - d.tcpAddrsLock.Unlock() if e == nil { - addrs, err := resolveTCPAddrs(addr, d.DualStack) + addrs, err := resolveTCPAddrs(addr, dualStack, d.Resolver) if err != nil { - d.tcpAddrsLock.Lock() - e = d.tcpAddrsMap[addr] - if e != nil && e.pending { + item, exist := d.tcpAddrsMap.Load(addr) + e, ok = item.(*tcpAddrEntry) + if exist && ok && e != nil && e.pending { e.pending = false } - d.tcpAddrsLock.Unlock() return nil, 0, err } @@ -323,20 +404,14 @@ addrs: addrs, resolveTime: time.Now(), } - - d.tcpAddrsLock.Lock() - d.tcpAddrsMap[addr] = e - d.tcpAddrsLock.Unlock() + d.tcpAddrsMap.Store(addr, e) } - idx := uint32(0) - if len(e.addrs) > 0 { - idx = atomic.AddUint32(&e.addrsIdx, 1) - } + idx := atomic.AddUint32(&e.addrsIdx, 1) return e.addrs, idx, nil } -func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) { +func resolveTCPAddrs(addr string, dualStack bool, resolver Resolver) ([]net.TCPAddr, error) { host, portS, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -346,22 +421,33 @@ return nil, err } - ips, err := net.LookupIP(host) + if resolver == nil { + resolver = net.DefaultResolver + } + + ctx := context.Background() + ipaddrs, err := resolver.LookupIPAddr(ctx, host) if err != nil { return nil, err } - n := len(ips) + n := len(ipaddrs) addrs := make([]net.TCPAddr, 0, n) for i := 0; i < n; i++ { - 
ip := ips[i] - if !dualStack && ip.To4() == nil { + ip := ipaddrs[i] + if !dualStack && ip.IP.To4() == nil { continue } addrs = append(addrs, net.TCPAddr{ - IP: ip, + IP: ip.IP, Port: port, + Zone: ip.Zone, }) } + if len(addrs) == 0 { + return nil, errNoDNSEntries + } return addrs, nil } + +var errNoDNSEntries = errors.New("couldn't find DNS entries for the given domain. Try using DialDualStack") diff -Nru golang-github-valyala-fasthttp-20160617/testdata/test.png golang-github-valyala-fasthttp-1.31.0/testdata/test.png --- golang-github-valyala-fasthttp-20160617/testdata/test.png 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/testdata/test.png 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1 @@ + diff -Nru golang-github-valyala-fasthttp-20160617/timer.go golang-github-valyala-fasthttp-1.31.0/timer.go --- golang-github-valyala-fasthttp-20160617/timer.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/timer.go 2021-10-09 18:39:05.000000000 +0000 @@ -26,7 +26,12 @@ } } -func acquireTimer(timeout time.Duration) *time.Timer { +// AcquireTimer returns a time.Timer from the pool and updates it to +// send the current time on its channel after at least timeout. +// +// The returned Timer may be returned to the pool with ReleaseTimer +// when no longer needed. This allows reducing GC load. +func AcquireTimer(timeout time.Duration) *time.Timer { v := timerPool.Get() if v == nil { return time.NewTimer(timeout) @@ -36,7 +41,12 @@ return t } -func releaseTimer(t *time.Timer) { +// ReleaseTimer returns the time.Timer acquired via AcquireTimer to the pool +// and prevents the Timer from firing. +// +// Do not access the released time.Timer or read from it's channel otherwise +// data races may occur. +func ReleaseTimer(t *time.Timer) { stopTimer(t) timerPool.Put(t) } diff -Nru golang-github-valyala-fasthttp-20160617/tls.go golang-github-valyala-fasthttp-1.31.0/tls.go --- golang-github-valyala-fasthttp-20160617/tls.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/tls.go 2021-10-09 18:39:05.000000000 +0000 @@ -0,0 +1,60 @@ +package fasthttp + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "time" +) + +// GenerateTestCertificate generates a test certificate and private key based on the given host. 
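Since tls.go only documents the generator itself, here is a hedged sketch of how the generated PEM pair is typically fed straight into a server. GenerateTestCertificate is defined just below; Server.ListenAndServeTLSEmbed is the existing fasthttp method that accepts in-memory PEM data; the address and handler are illustrative.

package main

import (
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	// Self-signed certificate and key for "localhost", both PEM-encoded.
	// Intended for tests only; the certificate acts as its own CA.
	certPEM, keyPEM, err := fasthttp.GenerateTestCertificate("localhost")
	if err != nil {
		log.Fatal(err)
	}

	s := &fasthttp.Server{
		Handler: func(ctx *fasthttp.RequestCtx) {
			ctx.SetBodyString("hello over TLS")
		},
	}
	// ListenAndServeTLSEmbed takes the PEM blocks directly, so nothing has
	// to be written to disk (handy now that the checked-in snakeoil
	// key/cert pair has been removed).
	log.Fatal(s.ListenAndServeTLSEmbed("127.0.0.1:8443", certPEM, keyPEM))
}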
+func GenerateTestCertificate(host string) ([]byte, []byte, error) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, nil, err + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, err + } + + cert := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"fasthttp test"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + SignatureAlgorithm: x509.SHA256WithRSA, + DNSNames: []string{host}, + BasicConstraintsValid: true, + IsCA: true, + } + + certBytes, err := x509.CreateCertificate( + rand.Reader, cert, cert, &priv.PublicKey, priv, + ) + + p := pem.EncodeToMemory( + &pem.Block{ + Type: "PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(priv), + }, + ) + + b := pem.EncodeToMemory( + &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }, + ) + + return b, p, err +} diff -Nru golang-github-valyala-fasthttp-20160617/.travis.yml golang-github-valyala-fasthttp-1.31.0/.travis.yml --- golang-github-valyala-fasthttp-20160617/.travis.yml 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/.travis.yml 1970-01-01 00:00:00.000000000 +0000 @@ -1,15 +0,0 @@ -language: go - -go: - - 1.6 - -script: - # build test for supported platforms - - GOOS=linux go build - - GOOS=darwin go build - - GOOS=freebsd go build - - GOOS=windows go build - - GOARCH=386 go build - - # run tests on a standard platform - - go test -v ./... diff -Nru golang-github-valyala-fasthttp-20160617/uri.go golang-github-valyala-fasthttp-1.31.0/uri.go --- golang-github-valyala-fasthttp-20160617/uri.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri.go 2021-10-09 18:39:05.000000000 +0000 @@ -2,7 +2,10 @@ import ( "bytes" + "errors" + "fmt" "io" + "strconv" "sync" ) @@ -36,7 +39,7 @@ // // URI instance MUST NOT be used from concurrently running goroutines. type URI struct { - noCopy noCopy + noCopy noCopy //nolint:unused,structcheck pathOriginal []byte scheme []byte @@ -48,10 +51,20 @@ queryArgs Args parsedQueryArgs bool + // Path values are sent as-is without normalization + // + // Disabled path normalization may be useful for proxying incoming requests + // to servers that are expecting paths to be forwarded as-is. + // + // By default path values are normalized, i.e. + // extra slashes are removed, special characters are encoded. + DisablePathNormalizing bool + fullURI []byte requestURI []byte - h *RequestHeader + username []byte + password []byte } // CopyTo copies uri contents to dst. @@ -63,18 +76,20 @@ dst.queryString = append(dst.queryString[:0], u.queryString...) dst.hash = append(dst.hash[:0], u.hash...) dst.host = append(dst.host[:0], u.host...) + dst.username = append(dst.username[:0], u.username...) + dst.password = append(dst.password[:0], u.password...) u.queryArgs.CopyTo(&dst.queryArgs) dst.parsedQueryArgs = u.parsedQueryArgs + dst.DisablePathNormalizing = u.DisablePathNormalizing // fullURI and requestURI shouldn't be copied, since they are created // from scratch on each FullURI() and RequestURI() call. - dst.h = u.h } // Hash returns URI hash, i.e. qwe of http://aaa.com/foo/bar?baz=123#qwe . // -// The returned value is valid until the next URI method call. 
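Aside on the new tls.go above: GenerateTestCertificate builds a self-signed RSA certificate for the given host and returns the certificate PEM, the key PEM and an error, in that order. A small sketch of consuming it with crypto/tls, handy for throwaway TLS test servers; the "localhost" host name is illustrative:

package main

import (
	"crypto/tls"
	"log"

	"github.com/valyala/fasthttp"
)

func main() {
	// Return order (certPEM, keyPEM, err) per the hunk above.
	certPEM, keyPEM, err := fasthttp.GenerateTestCertificate("localhost")
	if err != nil {
		log.Fatal(err)
	}

	// The PEM blocks can be fed straight into crypto/tls.
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		log.Fatal(err)
	}
	_ = cert // e.g. tls.Config{Certificates: []tls.Certificate{cert}}
}
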
+// The returned bytes are valid until the next URI method call. func (u *URI) Hash() []byte { return u.hash } @@ -89,10 +104,44 @@ u.hash = append(u.hash[:0], hash...) } +// Username returns URI username +// +// The returned bytes are valid until the next URI method call. +func (u *URI) Username() []byte { + return u.username +} + +// SetUsername sets URI username. +func (u *URI) SetUsername(username string) { + u.username = append(u.username[:0], username...) +} + +// SetUsernameBytes sets URI username. +func (u *URI) SetUsernameBytes(username []byte) { + u.username = append(u.username[:0], username...) +} + +// Password returns URI password +// +// The returned bytes are valid until the next URI method call. +func (u *URI) Password() []byte { + return u.password +} + +// SetPassword sets URI password. +func (u *URI) SetPassword(password string) { + u.password = append(u.password[:0], password...) +} + +// SetPasswordBytes sets URI password. +func (u *URI) SetPasswordBytes(password []byte) { + u.password = append(u.password[:0], password...) +} + // QueryString returns URI query string, // i.e. baz=123 of http://aaa.com/foo/bar?baz=123#qwe . // -// The returned value is valid until the next URI method call. +// The returned bytes are valid until the next URI method call. func (u *URI) QueryString() []byte { return u.queryString } @@ -114,7 +163,7 @@ // The returned path is always urldecoded and normalized, // i.e. '//f%20obar/baz/../zzz' becomes '/f obar/zzz'. // -// The returned value is valid until the next URI method call. +// The returned bytes are valid until the next URI method call. func (u *URI) Path() []byte { path := u.path if len(path) == 0 { @@ -137,7 +186,7 @@ // PathOriginal returns the original path from requestURI passed to URI.Parse(). // -// The returned value is valid until the next URI method call. +// The returned bytes are valid until the next URI method call. func (u *URI) PathOriginal() []byte { return u.pathOriginal } @@ -146,7 +195,7 @@ // // Returned scheme is always lowercased. // -// The returned value is valid until the next URI method call. +// The returned bytes are valid until the next URI method call. func (u *URI) Scheme() []byte { scheme := u.scheme if len(scheme) == 0 { @@ -174,29 +223,27 @@ u.path = u.path[:0] u.queryString = u.queryString[:0] u.hash = u.hash[:0] + u.username = u.username[:0] + u.password = u.password[:0] u.host = u.host[:0] u.queryArgs.Reset() u.parsedQueryArgs = false + u.DisablePathNormalizing = false // There is no need in u.fullURI = u.fullURI[:0], since full uri - // is calucalted on each call to FullURI(). + // is calculated on each call to FullURI(). // There is no need in u.requestURI = u.requestURI[:0], since requestURI // is calculated on each call to RequestURI(). - - u.h = nil } // Host returns host part, i.e. aaa.com of http://aaa.com/foo/bar?baz=123#qwe . // // Host is always lowercased. +// +// The returned bytes are valid until the next URI method call. func (u *URI) Host() []byte { - if len(u.host) == 0 && u.h != nil { - u.host = append(u.host[:0], u.h.Host()...) - lowercaseBytes(u.host) - u.h = nil - } return u.host } @@ -212,23 +259,58 @@ lowercaseBytes(u.host) } -// Parse initializes URI from the given host and uri. -func (u *URI) Parse(host, uri []byte) { - u.parse(host, uri, nil) -} +var ( + ErrorInvalidURI = errors.New("invalid uri") +) -func (u *URI) parseQuick(uri []byte, h *RequestHeader) { - u.parse(nil, uri, h) +// Parse initializes URI from the given host and uri. +// +// host may be nil. 
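Aside: with the Parse/Username/Password additions above, user information in the authority is split off the host and parsing can now fail (ErrorInvalidURI for control bytes, plus host-escape errors). A minimal usage sketch; the URL and credentials are made up:

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	u := fasthttp.AcquireURI()
	defer fasthttp.ReleaseURI(u)

	// Parse now reports errors instead of silently accepting bad input.
	if err := u.Parse(nil, []byte("http://user:secret@example.com/foo?bar=baz")); err != nil {
		panic(err)
	}

	// User information is stripped from the host and exposed via its own accessors.
	fmt.Println(string(u.Host()), string(u.Username()), string(u.Password()))
	// Output: example.com user secret
}
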
In this case uri must contain fully qualified uri, +// i.e. with scheme and host. http is assumed if scheme is omitted. +// +// uri may contain e.g. RequestURI without scheme and host if host is non-empty. +func (u *URI) Parse(host, uri []byte) error { + return u.parse(host, uri, false) } -func (u *URI) parse(host, uri []byte, h *RequestHeader) { +func (u *URI) parse(host, uri []byte, isTLS bool) error { u.Reset() - u.h = h - scheme, host, uri := splitHostURI(host, uri) - u.scheme = append(u.scheme, scheme...) - lowercaseBytes(u.scheme) + if stringContainsCTLByte(uri) { + return ErrorInvalidURI + } + + if len(host) == 0 || bytes.Contains(uri, strColonSlashSlash) { + scheme, newHost, newURI := splitHostURI(host, uri) + u.scheme = append(u.scheme, scheme...) + lowercaseBytes(u.scheme) + host = newHost + uri = newURI + } + + if isTLS { + u.scheme = append(u.scheme[:0], strHTTPS...) + } + + if n := bytes.IndexByte(host, '@'); n >= 0 { + auth := host[:n] + host = host[n+1:] + + if n := bytes.IndexByte(auth, ':'); n >= 0 { + u.username = append(u.username[:0], auth[:n]...) + u.password = append(u.password[:0], auth[n+1:]...) + } else { + u.username = append(u.username[:0], auth...) + u.password = u.password[:0] + } + } + u.host = append(u.host, host...) + if parsedHost, err := parseHost(u.host); err != nil { + return err + } else { + u.host = parsedHost + } lowercaseBytes(u.host) b := uri @@ -242,7 +324,7 @@ if queryIndex < 0 && fragmentIndex < 0 { u.pathOriginal = append(u.pathOriginal, b...) u.path = normalizePath(u.path, u.pathOriginal) - return + return nil } if queryIndex >= 0 { @@ -256,7 +338,7 @@ u.queryString = append(u.queryString, b[queryIndex+1:fragmentIndex]...) u.hash = append(u.hash, b[fragmentIndex+1:]...) } - return + return nil } // fragmentIndex >= 0 && queryIndex < 0 @@ -264,12 +346,234 @@ u.pathOriginal = append(u.pathOriginal, b[:fragmentIndex]...) u.path = normalizePath(u.path, u.pathOriginal) u.hash = append(u.hash, b[fragmentIndex+1:]...) + + return nil +} + +// parseHost parses host as an authority without user +// information. That is, as host[:port]. +// +// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L619 +// +// The host is parsed and unescaped in place overwriting the contents of the host parameter. +func parseHost(host []byte) ([]byte, error) { + if len(host) > 0 && host[0] == '[' { + // Parse an IP-Literal in RFC 3986 and RFC 6874. + // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80". + i := bytes.LastIndexByte(host, ']') + if i < 0 { + return nil, errors.New("missing ']' in host") + } + colonPort := host[i+1:] + if !validOptionalPort(colonPort) { + return nil, fmt.Errorf("invalid port %q after host", colonPort) + } + + // RFC 6874 defines that %25 (%-encoded percent) introduces + // the zone identifier, and the zone identifier can use basically + // any %-encoding it likes. That's different from the host, which + // can only %-encode non-ASCII bytes. + // We do impose some restrictions on the zone, to avoid stupidity + // like newlines. 
+ zone := bytes.Index(host[:i], []byte("%25")) + if zone >= 0 { + host1, err := unescape(host[:zone], encodeHost) + if err != nil { + return nil, err + } + host2, err := unescape(host[zone:i], encodeZone) + if err != nil { + return nil, err + } + host3, err := unescape(host[i:], encodeHost) + if err != nil { + return nil, err + } + return append(host1, append(host2, host3...)...), nil + } + } else if i := bytes.LastIndexByte(host, ':'); i != -1 { + colonPort := host[i:] + if !validOptionalPort(colonPort) { + return nil, fmt.Errorf("invalid port %q after host", colonPort) + } + } + + var err error + if host, err = unescape(host, encodeHost); err != nil { + return nil, err + } + return host, nil +} + +type encoding int + +const ( + encodeHost encoding = 1 + iota + encodeZone +) + +type EscapeError string + +func (e EscapeError) Error() string { + return "invalid URL escape " + strconv.Quote(string(e)) +} + +type InvalidHostError string + +func (e InvalidHostError) Error() string { + return "invalid character " + strconv.Quote(string(e)) + " in host name" +} + +// unescape unescapes a string; the mode specifies +// which section of the URL string is being unescaped. +// +// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L199 +// +// Unescapes in place overwriting the contents of s and returning it. +func unescape(s []byte, mode encoding) ([]byte, error) { + // Count %, check that they're well-formed. + n := 0 + for i := 0; i < len(s); { + switch s[i] { + case '%': + n++ + if i+2 >= len(s) || !ishex(s[i+1]) || !ishex(s[i+2]) { + s = s[i:] + if len(s) > 3 { + s = s[:3] + } + return nil, EscapeError(s) + } + // Per https://tools.ietf.org/html/rfc3986#page-21 + // in the host component %-encoding can only be used + // for non-ASCII bytes. + // But https://tools.ietf.org/html/rfc6874#section-2 + // introduces %25 being allowed to escape a percent sign + // in IPv6 scoped-address literals. Yay. + if mode == encodeHost && unhex(s[i+1]) < 8 && !bytes.Equal(s[i:i+3], []byte("%25")) { + return nil, EscapeError(s[i : i+3]) + } + if mode == encodeZone { + // RFC 6874 says basically "anything goes" for zone identifiers + // and that even non-ASCII can be redundantly escaped, + // but it seems prudent to restrict %-escaped bytes here to those + // that are valid host name bytes in their unescaped form. + // That is, you can use escaping in the zone identifier but not + // to introduce bytes you couldn't just write directly. + // But Windows puts spaces here! Yay. + v := unhex(s[i+1])<<4 | unhex(s[i+2]) + if !bytes.Equal(s[i:i+3], []byte("%25")) && v != ' ' && shouldEscape(v, encodeHost) { + return nil, EscapeError(s[i : i+3]) + } + } + i += 3 + default: + if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) { + return nil, InvalidHostError(s[i : i+1]) + } + i++ + } + } + + if n == 0 { + return s, nil + } + + t := s[:0] + for i := 0; i < len(s); i++ { + switch s[i] { + case '%': + t = append(t, unhex(s[i+1])<<4|unhex(s[i+2])) + i += 2 + default: + t = append(t, s[i]) + } + } + return t, nil +} + +// Return true if the specified character should be escaped when +// appearing in a URL string, according to RFC 3986. +// +// Please be informed that for now shouldEscape does not check all +// reserved characters correctly. See golang.org/issue/5684. 
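Aside: the parseHost/unescape helpers above decode percent-escapes in the host, including RFC 6874 %25 zone identifiers in IPv6 literals. A short sketch whose expected outputs mirror the new uri_test.go cases later in this diff:

package main

import (
	"fmt"

	"github.com/valyala/fasthttp"
)

func main() {
	u := fasthttp.AcquireURI()
	defer fasthttp.ReleaseURI(u)

	// %25 inside an IP-literal host introduces the zone identifier and is decoded to %.
	if err := u.Parse(nil, []byte("http://[fe80::1%25en0]:8080/")); err != nil {
		panic(err)
	}
	fmt.Println(string(u.Host())) // [fe80::1%en0]:8080

	// Percent-encoded non-ASCII host bytes are unescaped as well.
	if err := u.Parse(nil, []byte("http://hello.%e4%b8%96%e7%95%8c.com/foo")); err != nil {
		panic(err)
	}
	fmt.Println(string(u.Host())) // hello.世界.com
}
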
+// +// Based on https://github.com/golang/go/blob/8ac5cbe05d61df0a7a7c9a38ff33305d4dcfea32/src/net/url/url.go#L100 +func shouldEscape(c byte, mode encoding) bool { + // §2.3 Unreserved characters (alphanum) + if 'a' <= c && c <= 'z' || 'A' <= c && c <= 'Z' || '0' <= c && c <= '9' { + return false + } + + if mode == encodeHost || mode == encodeZone { + // §3.2.2 Host allows + // sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "=" + // as part of reg-name. + // We add : because we include :port as part of host. + // We add [ ] because we include [ipv6]:port as part of host. + // We add < > because they're the only characters left that + // we could possibly allow, and Parse will reject them if we + // escape them (because hosts can't use %-encoding for + // ASCII bytes). + switch c { + case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"': + return false + } + } + + if c == '-' || c == '_' || c == '.' || c == '~' { // §2.3 Unreserved characters (mark) + return false + } + + // Everything else must be escaped. + return true +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + } + return false +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +// validOptionalPort reports whether port is either an empty string +// or matches /^:\d*$/ +func validOptionalPort(port []byte) bool { + if len(port) == 0 { + return true + } + if port[0] != ':' { + return false + } + for _, b := range port[1:] { + if b < '0' || b > '9' { + return false + } + } + return true } func normalizePath(dst, src []byte) []byte { dst = dst[:0] dst = addLeadingSlash(dst, src) - dst = decodeArgAppend(dst, src, false) + dst = decodeArgAppendNoPlus(dst, src) // remove duplicate slashes b := dst @@ -318,7 +622,7 @@ if n >= 0 && n+len(strSlashDotDot) == len(b) { nn := bytes.LastIndexByte(b[:n], '/') if nn < 0 { - return strSlash + return append(dst[:0], strSlash...) } b = b[:nn+1] } @@ -328,18 +632,19 @@ // RequestURI returns RequestURI - i.e. URI without Scheme and Host. func (u *URI) RequestURI() []byte { - dst := appendQuotedPath(u.requestURI[:0], u.Path()) - if u.queryArgs.Len() > 0 { + var dst []byte + if u.DisablePathNormalizing { + dst = append(u.requestURI[:0], u.PathOriginal()...) + } else { + dst = appendQuotedPath(u.requestURI[:0], u.Path()) + } + if u.parsedQueryArgs && u.queryArgs.Len() > 0 { dst = append(dst, '?') dst = u.queryArgs.AppendBytes(dst) } else if len(u.queryString) > 0 { dst = append(dst, '?') dst = append(dst, u.queryString...) } - if len(u.hash) > 0 { - dst = append(dst, '#') - dst = append(dst, u.hash...) - } u.requestURI = dst return u.requestURI } @@ -351,6 +656,8 @@ // * For /foo/bar/baz.html path returns baz.html. // * For /foo/bar/ returns empty byte slice. // * For /foobar.js returns foobar.js. +// +// The returned bytes are valid until the next URI method call. func (u *URI) LastPathSegment() []byte { path := u.Path() n := bytes.LastIndexByte(path, '/') @@ -366,6 +673,8 @@ // // * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original // uri is replaced by newURI. +// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case +// the original scheme is preserved. // * Missing host, i.e. /aaa/bb?cc . 
In this case only RequestURI part // of the original uri is replaced. // * Relative path, i.e. xx?yy=abc . In this case the original RequestURI @@ -380,6 +689,8 @@ // // * Absolute, i.e. http://foobar.com/aaa/bb?cc . In this case the original // uri is replaced by newURI. +// * Absolute without scheme, i.e. //foobar.com/aaa/bb?cc. In this case +// the original scheme is preserved. // * Missing host, i.e. /aaa/bb?cc . In this case only RequestURI part // of the original uri is replaced. // * Relative path, i.e. xx?yy=abc . In this case the original RequestURI @@ -392,18 +703,31 @@ if len(newURI) == 0 { return buf } + + n := bytes.Index(newURI, strSlashSlash) + if n >= 0 { + // absolute uri + var b [32]byte + schemeOriginal := b[:0] + if len(u.scheme) > 0 { + schemeOriginal = append([]byte(nil), u.scheme...) + } + if err := u.Parse(nil, newURI); err != nil { + return nil + } + if len(schemeOriginal) > 0 && len(u.scheme) == 0 { + u.scheme = append(u.scheme[:0], schemeOriginal...) + } + return buf + } + if newURI[0] == '/' { // uri without host buf = u.appendSchemeHost(buf[:0]) buf = append(buf, newURI...) - u.Parse(nil, buf) - return buf - } - - n := bytes.Index(newURI, strColonSlashSlash) - if n >= 0 { - // absolute uri - u.Parse(nil, newURI) + if err := u.Parse(nil, buf); err != nil { + return nil + } return buf } @@ -422,17 +746,21 @@ path := u.Path() n = bytes.LastIndexByte(path, '/') if n < 0 { - panic("BUG: path must contain at least one slash") + panic(fmt.Sprintf("BUG: path must contain at least one slash: %s %s", u.Path(), newURI)) } buf = u.appendSchemeHost(buf[:0]) buf = appendQuotedPath(buf, path[:n+1]) buf = append(buf, newURI...) - u.Parse(nil, buf) + if err := u.Parse(nil, buf); err != nil { + return nil + } return buf } } // FullURI returns full uri in the form {Scheme}://{Host}{RequestURI}#{Hash}. +// +// The returned bytes are valid until the next URI method call. func (u *URI) FullURI() []byte { u.fullURI = u.AppendBytes(u.fullURI[:0]) return u.fullURI @@ -441,7 +769,12 @@ // AppendBytes appends full uri to dst and returns the extended dst. func (u *URI) AppendBytes(dst []byte) []byte { dst = u.appendSchemeHost(dst) - return append(dst, u.RequestURI()...) + dst = append(dst, u.RequestURI()...) + if len(u.hash) > 0 { + dst = append(dst, '#') + dst = append(dst, u.hash...) + } + return dst } func (u *URI) appendSchemeHost(dst []byte) []byte { @@ -464,7 +797,7 @@ } func splitHostURI(host, uri []byte) ([]byte, []byte, []byte) { - n := bytes.Index(uri, strColonSlashSlash) + n := bytes.Index(uri, strSlashSlash) if n < 0 { return strHTTP, host, uri } @@ -472,16 +805,30 @@ if bytes.IndexByte(scheme, '/') >= 0 { return strHTTP, host, uri } - n += len(strColonSlashSlash) + if len(scheme) > 0 && scheme[len(scheme)-1] == ':' { + scheme = scheme[:len(scheme)-1] + } + n += len(strSlashSlash) uri = uri[n:] n = bytes.IndexByte(uri, '/') - if n < 0 { + nq := bytes.IndexByte(uri, '?') + if nq >= 0 && nq < n { + // A hack for urls like foobar.com?a=b/xyz + n = nq + } else if n < 0 { + // A hack for bogus urls like foobar.com?a=b without + // slash after host. + if nq >= 0 { + return scheme, uri[:nq], uri[nq:] + } return scheme, uri, strSlash } return scheme, uri[:n], uri[n:] } // QueryArgs returns query args. +// +// The returned args are valid until the next URI method call. 
func (u *URI) QueryArgs() *Args { u.parseQueryArgs() return &u.queryArgs @@ -494,3 +841,14 @@ u.queryArgs.ParseBytes(u.queryString) u.parsedQueryArgs = true } + +// stringContainsCTLByte reports whether s contains any ASCII control character. +func stringContainsCTLByte(s []byte) bool { + for i := 0; i < len(s); i++ { + b := s[i] + if b < ' ' || b == 0x7f { + return true + } + } + return false +} diff -Nru golang-github-valyala-fasthttp-20160617/uri_test.go golang-github-valyala-fasthttp-1.31.0/uri_test.go --- golang-github-valyala-fasthttp-20160617/uri_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,11 +3,14 @@ import ( "bytes" "fmt" + "reflect" "testing" "time" ) func TestURICopyToQueryArgs(t *testing.T) { + t.Parallel() + var u URI a := u.QueryArgs() a.Set("foo", "bar") @@ -22,10 +25,14 @@ } func TestURIAcquireReleaseSequential(t *testing.T) { + t.Parallel() + testURIAcquireRelease(t) } func TestURIAcquireReleaseConcurrent(t *testing.T) { + t.Parallel() + ch := make(chan struct{}, 10) for i := 0; i < 10; i++ { go func() { @@ -49,7 +56,7 @@ host := fmt.Sprintf("host.%d.com", i*23) path := fmt.Sprintf("/foo/%d/bar", i*17) queryArgs := "?foo=bar&baz=aass" - u.Parse([]byte(host), []byte(path+queryArgs)) + u.Parse([]byte(host), []byte(path+queryArgs)) //nolint:errcheck if string(u.Host()) != host { t.Fatalf("unexpected host %q. Expecting %q", u.Host(), host) } @@ -61,6 +68,8 @@ } func TestURILastPathSegment(t *testing.T) { + t.Parallel() + testURILastPathSegment(t, "", "") testURILastPathSegment(t, "/", "") testURILastPathSegment(t, "/foo/bar/", "") @@ -78,6 +87,8 @@ } func TestURIPathEscape(t *testing.T) { + t.Parallel() + testURIPathEscape(t, "/foo/bar", "/foo/bar") testURIPathEscape(t, "/f_o-o=b:ar,b.c&q", "/f_o-o=b:ar,b.c&q") testURIPathEscape(t, "/aa?bb.тест~qq", "/aa%3Fbb.%D1%82%D0%B5%D1%81%D1%82~qq") @@ -93,6 +104,8 @@ } func TestURIUpdate(t *testing.T) { + t.Parallel() + // full uri testURIUpdate(t, "http://foo.bar/baz?aaa=22#aaa", "https://aa.com/bb", "https://aa.com/bb") @@ -112,11 +125,15 @@ // hash testURIUpdate(t, "http://foo.bar/baz#aaa", "#fragment", "http://foo.bar/baz#fragment") + + // uri without scheme + testURIUpdate(t, "https://foo.bar/baz", "//aaa.bbb/cc?dd", "https://aaa.bbb/cc?dd") + testURIUpdate(t, "http://foo.bar/baz", "//aaa.bbb/cc?dd", "http://aaa.bbb/cc?dd") } func testURIUpdate(t *testing.T, base, update, result string) { var u URI - u.Parse(nil, []byte(base)) + u.Parse(nil, []byte(base)) //nolint:errcheck u.Update(update) s := u.String() if s != result { @@ -125,6 +142,8 @@ } func TestURIPathNormalize(t *testing.T) { + t.Parallel() + var u URI // double slash @@ -173,13 +192,45 @@ } func testURIPathNormalize(t *testing.T, u *URI, requestURI, expectedPath string) { - u.Parse(nil, []byte(requestURI)) + u.Parse(nil, []byte(requestURI)) //nolint:errcheck if string(u.Path()) != expectedPath { t.Fatalf("Unexpected path %q. Expected %q. requestURI=%q", u.Path(), expectedPath, requestURI) } } +func TestURINoNormalization(t *testing.T) { + t.Parallel() + + var u URI + irregularPath := "/aaa%2Fbbb%2F%2E.%2Fxxx" + u.Parse(nil, []byte(irregularPath)) //nolint:errcheck + u.DisablePathNormalizing = true + if string(u.RequestURI()) != irregularPath { + t.Fatalf("Unexpected path %q. 
Expected %q.", u.Path(), irregularPath) + } +} + +func TestURICopyTo(t *testing.T) { + t.Parallel() + + var u URI + var copyU URI + u.CopyTo(©U) + if !reflect.DeepEqual(u, copyU) { //nolint:govet + t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet + } + + u.UpdateBytes([]byte("https://google.com/foo?bar=baz&baraz#qqqq")) + u.CopyTo(©U) + if !reflect.DeepEqual(u, copyU) { //nolint:govet + t.Fatalf("URICopyTo fail, u: \n%+v\ncopyu: \n%+v\n", u, copyU) //nolint:govet + } + +} + func TestURIFullURI(t *testing.T) { + t.Parallel() + var args Args // empty scheme, path and hash @@ -201,7 +252,7 @@ // test with empty args and non-empty query string var u URI - u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) + u.Parse([]byte("google.com"), []byte("/foo?bar=baz&baraz#qqqq")) //nolint:errcheck uri := u.FullURI() expectedURI := "http://google.com/foo?bar=baz&baraz#qqqq" if string(uri) != expectedURI { @@ -225,22 +276,43 @@ } func TestURIParseNilHost(t *testing.T) { - testURIParseScheme(t, "http://google.com/foo?bar#baz", "http") - testURIParseScheme(t, "HTtP://google.com/", "http") - testURIParseScheme(t, "://google.com/", "http") - testURIParseScheme(t, "fTP://aaa.com", "ftp") - testURIParseScheme(t, "httPS://aaa.com", "https") + t.Parallel() + + testURIParseScheme(t, "http://google.com/foo?bar#baz", "http", "google.com", "/foo?bar", "baz") + testURIParseScheme(t, "HTtP://google.com/", "http", "google.com", "/", "") + testURIParseScheme(t, "://google.com/xyz", "http", "google.com", "/xyz", "") + testURIParseScheme(t, "//google.com/foobar", "http", "google.com", "/foobar", "") + testURIParseScheme(t, "fTP://aaa.com", "ftp", "aaa.com", "/", "") + testURIParseScheme(t, "httPS://aaa.com", "https", "aaa.com", "/", "") + + // missing slash after hostname + testURIParseScheme(t, "http://foobar.com?baz=111", "http", "foobar.com", "/?baz=111", "") + + // slash in args + testURIParseScheme(t, "http://foobar.com?baz=111/222/xyz", "http", "foobar.com", "/?baz=111/222/xyz", "") + testURIParseScheme(t, "http://foobar.com?111/222/xyz", "http", "foobar.com", "/?111/222/xyz", "") } -func testURIParseScheme(t *testing.T, uri, expectedScheme string) { +func testURIParseScheme(t *testing.T, uri, expectedScheme, expectedHost, expectedRequestURI, expectedHash string) { var u URI - u.Parse(nil, []byte(uri)) + u.Parse(nil, []byte(uri)) //nolint:errcheck if string(u.Scheme()) != expectedScheme { - t.Fatalf("Unexpected scheme %q. Expected %q for uri %q", u.Scheme(), expectedScheme, uri) + t.Fatalf("Unexpected scheme %q. Expecting %q for uri %q", u.Scheme(), expectedScheme, uri) + } + if string(u.Host()) != expectedHost { + t.Fatalf("Unexepcted host %q. Expecting %q for uri %q", u.Host(), expectedHost, uri) + } + if string(u.RequestURI()) != expectedRequestURI { + t.Fatalf("Unexepcted requestURI %q. Expecting %q for uri %q", u.RequestURI(), expectedRequestURI, uri) + } + if string(u.hash) != expectedHash { + t.Fatalf("Unexepcted hash %q. 
Expecting %q for uri %q", u.hash, expectedHash, uri) } } func TestURIParse(t *testing.T) { + t.Parallel() + var u URI // no args @@ -261,7 +333,7 @@ // encoded path testURIParse(t, &u, "aa.com", "/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", - "http://aa.com/Test%20%2B%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "aa.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf") + "http://aa.com/Test%20+%20%D0%BF%D1%80%D0%B8?asdf=%20%20&s=12#sdf", "aa.com", "/Test + при", "/Test%20+%20%D0%BF%D1%80%D0%B8", "asdf=%20%20&s=12", "sdf") // host in uppercase testURIParse(t, &u, "FOObar.COM", "/bC?De=F#Gh", @@ -284,11 +356,32 @@ // http:// in query params testURIParse(t, &u, "aaa.com", "/foo?bar=http://google.com", "http://aaa.com/foo?bar=http://google.com", "aaa.com", "/foo", "/foo", "bar=http://google.com", "") + + testURIParse(t, &u, "aaa.com", "//relative", + "http://aaa.com/relative", "aaa.com", "/relative", "//relative", "", "") + + testURIParse(t, &u, "", "//aaa.com//absolute", + "http://aaa.com/absolute", "aaa.com", "/absolute", "//absolute", "", "") + + testURIParse(t, &u, "", "//aaa.com\r\n\r\nGET x", + "http:///", "", "/", "", "", "") + + testURIParse(t, &u, "", "http://[fe80::1%25en0]/", + "http://[fe80::1%en0]/", "[fe80::1%en0]", "/", "/", "", "") + + testURIParse(t, &u, "", "http://[fe80::1%25en0]:8080/", + "http://[fe80::1%en0]:8080/", "[fe80::1%en0]:8080", "/", "/", "", "") + + testURIParse(t, &u, "", "http://hello.世界.com/foo", + "http://hello.世界.com/foo", "hello.世界.com", "/foo", "/foo", "", "") + + testURIParse(t, &u, "", "http://hello.%e4%b8%96%e7%95%8c.com/foo", + "http://hello.世界.com/foo", "hello.世界.com", "/foo", "/foo", "", "") } func testURIParse(t *testing.T, u *URI, host, uri, expectedURI, expectedHost, expectedPath, expectedPathOriginal, expectedArgs, expectedHash string) { - u.Parse([]byte(host), []byte(uri)) + u.Parse([]byte(host), []byte(uri)) //nolint:errcheck if !bytes.Equal(u.FullURI(), []byte(expectedURI)) { t.Fatalf("Unexpected uri %q. Expected %q. host=%q, uri=%q", u.FullURI(), expectedURI, host, uri) @@ -309,3 +402,48 @@ t.Fatalf("Unexpected hash %q. Expected %q. 
host=%q, uri=%q", u.Hash(), expectedHash, host, uri) } } + +func TestURIWithQuerystringOverride(t *testing.T) { + t.Parallel() + + var u URI + u.SetQueryString("q1=foo&q2=bar") + u.QueryArgs().Add("q3", "baz") + u.SetQueryString("q1=foo&q2=bar&q4=quux") + uriString := string(u.RequestURI()) + + if uriString != "/?q1=foo&q2=bar&q4=quux" { + t.Fatalf("Expected Querystring to be overridden but was %s ", uriString) + } +} + +func TestInvalidUrl(t *testing.T) { + url := `https://.çèéà@&~!&:=\\/\"'~<>|+-*()[]{}%$;,¥&&$22|||<>< 4ly8lzjmoNx233AXELDtyaFQiiUH-fd8c-CnXUJVYnGIs4Uwr-bptom5GCnWtsGMQxeM2ZhoKE973eKgs2Sjh6RePnyaLpCi6SiNSLevcMoraARrp88L-SgtKqd-XHAtSI8hiPRiXPQmDIA4BGhSgoc0nfn1PoYuGKKmDcZ04tANRc3iz4aF4-A1UrO8bLHTH7MEJvzx.someqa.fr/A/?&QS_BEGIN<&8{b'Ob=p*f> QS_END` + + u := AcquireURI() + defer ReleaseURI(u) + + if err := u.Parse(nil, []byte(url)); err == nil { + t.Fail() + } +} + +func TestNoOverwriteInput(t *testing.T) { + str := `//%AA` + url := []byte(str) + + u := AcquireURI() + defer ReleaseURI(u) + + if err := u.Parse(nil, url); err != nil { + t.Error(err) + } + + if string(url) != str { + t.Error() + } + + if u.String() != "http://\xaa/" { + t.Errorf("%q", u.String()) + } +} diff -Nru golang-github-valyala-fasthttp-20160617/uri_timing_test.go golang-github-valyala-fasthttp-1.31.0/uri_timing_test.go --- golang-github-valyala-fasthttp-20160617/uri_timing_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri_timing_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -27,7 +27,7 @@ b.RunParallel(func(pb *testing.PB) { var u URI - u.Parse(host, requestURI) + u.Parse(host, requestURI) //nolint:errcheck for pb.Next() { uri := u.FullURI() if len(uri) != uriLen { @@ -43,7 +43,7 @@ b.RunParallel(func(pb *testing.PB) { var u URI for pb.Next() { - u.Parse(strHost, strURI) + u.Parse(strHost, strURI) //nolint:errcheck } }) } diff -Nru golang-github-valyala-fasthttp-20160617/uri_unix.go golang-github-valyala-fasthttp-1.31.0/uri_unix.go --- golang-github-valyala-fasthttp-20160617/uri_unix.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri_unix.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,3 +1,4 @@ +//go:build !windows // +build !windows package fasthttp diff -Nru golang-github-valyala-fasthttp-20160617/uri_windows.go golang-github-valyala-fasthttp-1.31.0/uri_windows.go --- golang-github-valyala-fasthttp-20160617/uri_windows.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri_windows.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,3 +1,4 @@ +//go:build windows // +build windows package fasthttp diff -Nru golang-github-valyala-fasthttp-20160617/uri_windows_test.go golang-github-valyala-fasthttp-1.31.0/uri_windows_test.go --- golang-github-valyala-fasthttp-20160617/uri_windows_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/uri_windows_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,3 +1,4 @@ +//go:build windows // +build windows package fasthttp @@ -5,6 +6,8 @@ import "testing" func TestURIPathNormalizeIssue86(t *testing.T) { + t.Parallel() + // see https://github.com/valyala/fasthttp/issues/86 var u URI diff -Nru golang-github-valyala-fasthttp-20160617/userdata.go golang-github-valyala-fasthttp-1.31.0/userdata.go --- golang-github-valyala-fasthttp-20160617/userdata.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/userdata.go 2021-10-09 18:39:05.000000000 +0000 @@ -22,6 +22,10 @@ } } + if value == nil { + return + } + c := cap(args) if c 
> n { args = args[:n+1] @@ -69,3 +73,23 @@ } *d = (*d)[:0] } + +func (d *userData) Remove(key string) { + args := *d + n := len(args) + for i := 0; i < n; i++ { + kv := &args[i] + if string(kv.key) == key { + n-- + args[i] = args[n] + args[n].value = nil + args = args[:n] + *d = args + return + } + } +} + +func (d *userData) RemoveBytes(key []byte) { + d.Remove(b2s(key)) +} diff -Nru golang-github-valyala-fasthttp-20160617/userdata_test.go golang-github-valyala-fasthttp-1.31.0/userdata_test.go --- golang-github-valyala-fasthttp-20160617/userdata_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/userdata_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -7,6 +7,8 @@ ) func TestUserData(t *testing.T) { + t.Parallel() + var u userData for i := 0; i < 10; i++ { @@ -41,6 +43,8 @@ } func TestUserDataValueClose(t *testing.T) { + t.Parallel() + var u userData closeCalls := 0 @@ -72,3 +76,31 @@ (*cv.closeCalls)++ return nil } + +func TestUserDataDelete(t *testing.T) { + t.Parallel() + + var u userData + + for i := 0; i < 10; i++ { + key := fmt.Sprintf("key_%d", i) + u.Set(key, i) + testUserDataGet(t, &u, []byte(key), i) + } + + for i := 0; i < 10; i += 2 { + k := fmt.Sprintf("key_%d", i) + u.Remove(k) + if val := u.Get(k); val != nil { + t.Fatalf("unexpected key= %s, value =%v ,Expecting key= %s, value = nil", k, val, k) + } + kk := fmt.Sprintf("key_%d", i+1) + testUserDataGet(t, &u, []byte(kk), i+1) + } + for i := 0; i < 10; i++ { + key := fmt.Sprintf("key_new_%d", i) + u.Set(key, i) + testUserDataGet(t, &u, []byte(key), i) + } + +} diff -Nru golang-github-valyala-fasthttp-20160617/workerpool.go golang-github-valyala-fasthttp-1.31.0/workerpool.go --- golang-github-valyala-fasthttp-20160617/workerpool.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/workerpool.go 2021-10-09 18:39:05.000000000 +0000 @@ -3,7 +3,6 @@ import ( "net" "runtime" - "runtime/debug" "strings" "sync" "time" @@ -17,7 +16,7 @@ type workerPool struct { // Function for serving server connections. // It must leave c unclosed. - WorkerFunc func(c net.Conn) error + WorkerFunc ServeHandler MaxWorkersCount int @@ -36,6 +35,8 @@ stopCh chan struct{} workerChanPool sync.Pool + + connState func(net.Conn, ConnState) } type workerChan struct { @@ -49,6 +50,11 @@ } wp.stopCh = make(chan struct{}) stopCh := wp.stopCh + wp.workerChanPool.New = func() interface{} { + return &workerChan{ + ch: make(chan net.Conn, workerChanCap), + } + } go func() { var scratch []*workerChan for { @@ -75,8 +81,8 @@ // serving the connection and noticing wp.mustStop = true. wp.lock.Lock() ready := wp.ready - for i, ch := range ready { - ch.ch <- nil + for i := range ready { + ready[i].ch <- nil ready[i] = nil } wp.ready = ready[:0] @@ -96,23 +102,34 @@ // Clean least recently used workers if they didn't serve connections // for more than maxIdleWorkerDuration. - currentTime := time.Now() + criticalTime := time.Now().Add(-maxIdleWorkerDuration) wp.lock.Lock() ready := wp.ready n := len(ready) - i := 0 - for i < n && currentTime.Sub(ready[i].lastUseTime) > maxIdleWorkerDuration { - i++ - } - *scratch = append((*scratch)[:0], ready[:i]...) - if i > 0 { - m := copy(ready, ready[i:]) - for i = m; i < n; i++ { - ready[i] = nil + + // Use binary-search algorithm to find out the index of the least recently worker which can be cleaned up. 
+ l, r, mid := 0, n-1, 0 + for l <= r { + mid = (l + r) / 2 + if criticalTime.After(wp.ready[mid].lastUseTime) { + l = mid + 1 + } else { + r = mid - 1 } - wp.ready = ready[:m] } + i := r + if i == -1 { + wp.lock.Unlock() + return + } + + *scratch = append((*scratch)[:0], ready[:i+1]...) + m := copy(ready, ready[i+1:]) + for i = m; i < n; i++ { + ready[i] = nil + } + wp.ready = ready[:m] wp.lock.Unlock() // Notify obsolete workers to stop. @@ -120,8 +137,8 @@ // may be blocking and may consume a lot of time if many workers // are located on non-local CPUs. tmp := *scratch - for i, ch := range tmp { - ch.ch <- nil + for i := range tmp { + tmp[i].ch <- nil tmp[i] = nil } } @@ -173,11 +190,6 @@ return nil } vch := wp.workerChanPool.Get() - if vch == nil { - vch = &workerChan{ - ch: make(chan net.Conn, workerChanCap), - } - } ch = vch.(*workerChan) go func() { wp.workerFunc(ch) @@ -202,19 +214,6 @@ func (wp *workerPool) workerFunc(ch *workerChan) { var c net.Conn - defer func() { - if r := recover(); r != nil { - wp.Logger.Printf("panic: %s\nStack trace:\n%s", r, debug.Stack()) - if c != nil { - c.Close() - } - } - - wp.lock.Lock() - wp.workersCount-- - wp.lock.Unlock() - }() - var err error for c = range ch.ch { if c == nil { @@ -225,12 +224,17 @@ errStr := err.Error() if wp.LogAllErrors || !(strings.Contains(errStr, "broken pipe") || strings.Contains(errStr, "reset by peer") || + strings.Contains(errStr, "request headers: small read buffer") || + strings.Contains(errStr, "unexpected EOF") || strings.Contains(errStr, "i/o timeout")) { wp.Logger.Printf("error when serving connection %q<->%q: %s", c.LocalAddr(), c.RemoteAddr(), err) } } - if err != errHijacked { - c.Close() + if err == errHijacked { + wp.connState(c, StateHijacked) + } else { + _ = c.Close() + wp.connState(c, StateClosed) } c = nil @@ -238,4 +242,8 @@ break } } + + wp.lock.Lock() + wp.workersCount-- + wp.lock.Unlock() } diff -Nru golang-github-valyala-fasthttp-20160617/workerpool_test.go golang-github-valyala-fasthttp-1.31.0/workerpool_test.go --- golang-github-valyala-fasthttp-20160617/workerpool_test.go 2016-06-17 10:13:04.000000000 +0000 +++ golang-github-valyala-fasthttp-1.31.0/workerpool_test.go 2021-10-09 18:39:05.000000000 +0000 @@ -1,10 +1,8 @@ package fasthttp import ( - "fmt" "io/ioutil" "net" - "sync/atomic" "testing" "time" @@ -12,10 +10,14 @@ ) func TestWorkerPoolStartStopSerial(t *testing.T) { + t.Parallel() + testWorkerPoolStartStop(t) } func TestWorkerPoolStartStopConcurrent(t *testing.T) { + t.Parallel() + concurrency := 10 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { @@ -46,10 +48,14 @@ } func TestWorkerPoolMaxWorkersCountSerial(t *testing.T) { + t.Parallel() + testWorkerPoolMaxWorkersCountMulti(t) } func TestWorkerPoolMaxWorkersCountConcurrent(t *testing.T) { + t.Parallel() + concurrency := 4 ch := make(chan struct{}, concurrency) for i := 0; i < concurrency; i++ { @@ -61,7 +67,7 @@ for i := 0; i < concurrency; i++ { select { case <-ch: - case <-time.After(time.Second): + case <-time.After(time.Second * 2): t.Fatalf("timeout") } } @@ -80,14 +86,14 @@ buf := make([]byte, 100) n, err := conn.Read(buf) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } buf = buf[:n] if string(buf) != "foobar" { - t.Fatalf("unexpected data read: %q. Expecting %q", buf, "foobar") + t.Errorf("unexpected data read: %q. 
Expecting %q", buf, "foobar") } if _, err = conn.Write([]byte("baz")); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } <-ready @@ -96,6 +102,7 @@ }, MaxWorkersCount: 10, Logger: defaultLogger, + connState: func(net.Conn, ConnState) {}, } wp.Start() @@ -106,20 +113,20 @@ go func() { conn, err := ln.Dial() if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } if _, err = conn.Write([]byte("foobar")); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } data, err := ioutil.ReadAll(conn) if err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } if string(data) != "baz" { - t.Fatalf("unexpected value read: %q. Expecting %q", data, "baz") + t.Errorf("unexpected value read: %q. Expecting %q", data, "baz") } if err = conn.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } clientCh <- struct{}{} }() @@ -137,7 +144,7 @@ go func() { if _, err := ln.Dial(); err != nil { - t.Fatalf("unexpected error: %s", err) + t.Errorf("unexpected error: %s", err) } }() conn, err := ln.Accept() @@ -168,100 +175,3 @@ } wp.Stop() } - -func TestWorkerPoolPanicErrorSerial(t *testing.T) { - testWorkerPoolPanicErrorMulti(t) -} - -func TestWorkerPoolPanicErrorConcurrent(t *testing.T) { - concurrency := 10 - ch := make(chan struct{}, concurrency) - for i := 0; i < concurrency; i++ { - go func() { - testWorkerPoolPanicErrorMulti(t) - ch <- struct{}{} - }() - } - for i := 0; i < concurrency; i++ { - select { - case <-ch: - case <-time.After(time.Second): - t.Fatalf("timeout") - } - } -} - -func testWorkerPoolPanicErrorMulti(t *testing.T) { - var globalCount uint64 - wp := &workerPool{ - WorkerFunc: func(conn net.Conn) error { - count := atomic.AddUint64(&globalCount, 1) - switch count % 3 { - case 0: - panic("foobar") - case 1: - return fmt.Errorf("fake error") - } - return nil - }, - MaxWorkersCount: 1000, - MaxIdleWorkerDuration: time.Millisecond, - Logger: &customLogger{}, - } - - for i := 0; i < 10; i++ { - testWorkerPoolPanicError(t, wp) - } -} - -func testWorkerPoolPanicError(t *testing.T, wp *workerPool) { - wp.Start() - - ln := fasthttputil.NewInmemoryListener() - - clientsCount := 10 - clientCh := make(chan struct{}, clientsCount) - for i := 0; i < clientsCount; i++ { - go func() { - conn, err := ln.Dial() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - data, err := ioutil.ReadAll(conn) - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if len(data) > 0 { - t.Fatalf("unexpected data read: %q. Expecting empty data", data) - } - if err = conn.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) - } - clientCh <- struct{}{} - }() - } - - for i := 0; i < clientsCount; i++ { - conn, err := ln.Accept() - if err != nil { - t.Fatalf("unexpected error: %s", err) - } - if !wp.Serve(conn) { - t.Fatalf("worker pool mustn't be full") - } - } - - for i := 0; i < clientsCount; i++ { - select { - case <-clientCh: - case <-time.After(time.Second): - t.Fatalf("timeout") - } - } - - if err := ln.Close(); err != nil { - t.Fatalf("unexpected error: %s", err) - } - - wp.Stop() -}
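Closing aside on the timer.go hunk earlier in this diff: acquireTimer/releaseTimer are promoted to the exported AcquireTimer/ReleaseTimer, so callers can reuse pooled timers instead of allocating one per operation. A minimal sketch of that pattern; the helper name and the 10ms timeout are invented for illustration:

package main

import (
	"fmt"
	"time"

	"github.com/valyala/fasthttp"
)

func waitWithTimeout(ch <-chan struct{}, timeout time.Duration) error {
	// Grab a pooled timer instead of time.NewTimer on every call.
	t := fasthttp.AcquireTimer(timeout)
	defer fasthttp.ReleaseTimer(t)

	select {
	case <-ch:
		return nil
	case <-t.C:
		return fmt.Errorf("timed out after %s", timeout)
	}
}

func main() {
	ch := make(chan struct{}, 1)
	ch <- struct{}{}
	fmt.Println(waitWithTimeout(ch, 10*time.Millisecond)) // <nil>
}
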