diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/array.go golang-github-lib-pq-1.3.0/array.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/array.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/array.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,756 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/hex" + "fmt" + "reflect" + "strconv" + "strings" +) + +var typeByteSlice = reflect.TypeOf([]byte{}) +var typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem() +var typeSQLScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem() + +// Array returns the optimal driver.Valuer and sql.Scanner for an array or +// slice of any dimension. +// +// For example: +// db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401})) +// +// var x []sql.NullInt64 +// db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x)) +// +// Scanning multi-dimensional arrays is not supported. Arrays where the lower +// bound is not one (such as `[0:0]={1}') are not supported. +func Array(a interface{}) interface { + driver.Valuer + sql.Scanner +} { + switch a := a.(type) { + case []bool: + return (*BoolArray)(&a) + case []float64: + return (*Float64Array)(&a) + case []int64: + return (*Int64Array)(&a) + case []string: + return (*StringArray)(&a) + + case *[]bool: + return (*BoolArray)(a) + case *[]float64: + return (*Float64Array)(a) + case *[]int64: + return (*Int64Array)(a) + case *[]string: + return (*StringArray)(a) + } + + return GenericArray{a} +} + +// ArrayDelimiter may be optionally implemented by driver.Valuer or sql.Scanner +// to override the array delimiter used by GenericArray. +type ArrayDelimiter interface { + // ArrayDelimiter returns the delimiter character(s) for this element's type. + ArrayDelimiter() string +} + +// BoolArray represents a one-dimensional array of the PostgreSQL boolean type. +type BoolArray []bool + +// Scan implements the sql.Scanner interface. +func (a *BoolArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to BoolArray", src) +} + +func (a *BoolArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "BoolArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(BoolArray, len(elems)) + for i, v := range elems { + if len(v) != 1 { + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + switch v[0] { + case 't': + b[i] = true + case 'f': + b[i] = false + default: + return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a BoolArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be exactly two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1+2*n) + + for i := 0; i < n; i++ { + b[2*i] = ',' + if a[i] { + b[1+2*i] = 't' + } else { + b[1+2*i] = 'f' + } + } + + b[0] = '{' + b[2*n] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// ByteaArray represents a one-dimensional array of the PostgreSQL bytea type. +type ByteaArray [][]byte + +// Scan implements the sql.Scanner interface. 
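// Before ByteaArray's Scan below, a minimal sketch of how the Array helper and
// the typed arrays above are used from database/sql. The DSN and the
// items/tags schema are assumptions made for this example, not part of pq.

package main

import (
	"database/sql"
	"log"

	"github.com/lib/pq"
)

func main() {
	db, err := sql.Open("postgres", "dbname=example sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// A []int64 argument is wrapped as *Int64Array, whose Value() emits {235,401}.
	rows, err := db.Query(`SELECT id, tags FROM items WHERE id = ANY($1)`, pq.Array([]int64{235, 401}))
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	for rows.Next() {
		var id int64
		var tags []string
		// pq.Array(&tags) returns *StringArray, whose Scan parses the {...} text form.
		if err := rows.Scan(&id, pq.Array(&tags)); err != nil {
			log.Fatal(err)
		}
		log.Println(id, tags)
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}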
+func (a *ByteaArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to ByteaArray", src) +} + +func (a *ByteaArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "ByteaArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(ByteaArray, len(elems)) + for i, v := range elems { + b[i], err = parseBytea(v) + if err != nil { + return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error()) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. It uses the "hex" format which +// is only supported on PostgreSQL 9.0 or newer. +func (a ByteaArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // 3*N bytes of hex formatting, and N-1 bytes of delimiters. + size := 1 + 6*n + for _, x := range a { + size += hex.EncodedLen(len(x)) + } + + b := make([]byte, size) + + for i, s := 0, b; i < n; i++ { + o := copy(s, `,"\\x`) + o += hex.Encode(s[o:], a[i]) + s[o] = '"' + s = s[o+1:] + } + + b[0] = '{' + b[size-1] = '}' + + return string(b), nil + } + + return "{}", nil +} + +// Float64Array represents a one-dimensional array of the PostgreSQL double +// precision type. +type Float64Array []float64 + +// Scan implements the sql.Scanner interface. +func (a *Float64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Float64Array", src) +} + +func (a *Float64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Float64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Float64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseFloat(string(v), 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Float64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendFloat(b, a[0], 'f', -1, 64) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendFloat(b, a[i], 'f', -1, 64) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// GenericArray implements the driver.Valuer and sql.Scanner interfaces for +// an array or slice of any dimension. +type GenericArray struct{ A interface{} } + +func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) { + var assign func([]byte, reflect.Value) error + var del = "," + + // TODO calculate the assign function for other types + // TODO repeat this section on the element type of arrays or slices (multidimensional) + { + if reflect.PtrTo(rt).Implements(typeSQLScanner) { + // dest is always addressable because it is an element of a slice. 
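	// For instance, sql.NullString satisfies this branch: NullString itself is
	// not a sql.Scanner, but *sql.NullString is, so each element is filled in
	// through its address. A sketch of the check this branch performs (ns is an
	// invented name for illustration):
	//
	//	var ns sql.NullString
	//	rt := reflect.TypeOf(ns)
	//	reflect.PtrTo(rt).Implements(typeSQLScanner) // true, so assign uses dest.Addr()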
+ assign = func(src []byte, dest reflect.Value) (err error) { + ss := dest.Addr().Interface().(sql.Scanner) + if src == nil { + err = ss.Scan(nil) + } else { + err = ss.Scan(src) + } + return + } + goto FoundType + } + + assign = func([]byte, reflect.Value) error { + return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt) + } + } + +FoundType: + + if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + return rt, assign, del +} + +// Scan implements the sql.Scanner interface. +func (a GenericArray) Scan(src interface{}) error { + dpv := reflect.ValueOf(a.A) + switch { + case dpv.Kind() != reflect.Ptr: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + case dpv.IsNil(): + return fmt.Errorf("pq: destination %T is nil", a.A) + } + + dv := dpv.Elem() + switch dv.Kind() { + case reflect.Slice: + case reflect.Array: + default: + return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A) + } + + switch src := src.(type) { + case []byte: + return a.scanBytes(src, dv) + case string: + return a.scanBytes([]byte(src), dv) + case nil: + if dv.Kind() == reflect.Slice { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + } + + return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type()) +} + +func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error { + dtype, assign, del := a.evaluateDestination(dv.Type().Elem()) + dims, elems, err := parseArray(src, []byte(del)) + if err != nil { + return err + } + + // TODO allow multidimensional + + if len(dims) > 1 { + return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented", + strings.Replace(fmt.Sprint(dims), " ", "][", -1)) + } + + // Treat a zero-dimensional array like an array with a single dimension of zero. + if len(dims) == 0 { + dims = append(dims, 0) + } + + for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() { + switch rt.Kind() { + case reflect.Slice: + case reflect.Array: + if rt.Len() != dims[i] { + return fmt.Errorf("pq: cannot convert ARRAY%s to %s", + strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type()) + } + default: + // TODO handle multidimensional + } + } + + values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems)) + for i, e := range elems { + if err := assign(e, values.Index(i)); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + + // TODO handle multidimensional + + switch dv.Kind() { + case reflect.Slice: + dv.Set(values.Slice(0, dims[0])) + case reflect.Array: + for i := 0; i < dims[0]; i++ { + dv.Index(i).Set(values.Index(i)) + } + } + + return nil +} + +// Value implements the driver.Valuer interface. +func (a GenericArray) Value() (driver.Value, error) { + if a.A == nil { + return nil, nil + } + + rv := reflect.ValueOf(a.A) + + switch rv.Kind() { + case reflect.Slice: + if rv.IsNil() { + return nil, nil + } + case reflect.Array: + default: + return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A) + } + + if n := rv.Len(); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 0, 1+2*n) + + b, _, err := appendArray(b, rv, n) + return string(b), err + } + + return "{}", nil +} + +// Int64Array represents a one-dimensional array of the PostgreSQL integer types. +type Int64Array []int64 + +// Scan implements the sql.Scanner interface. 
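// Before Int64Array's Scan below, a minimal sketch of the GenericArray path
// just defined: scanning an array containing NULL into []sql.NullInt64, as in
// the Array doc comment. The DSN is an assumption made for this example.

package main

import (
	"database/sql"
	"fmt"
	"log"

	"github.com/lib/pq"
)

func main() {
	db, err := sql.Open("postgres", "dbname=example sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// NULL elements need a sql.Scanner element type, so pq.Array falls back to
	// GenericArray here rather than Int64Array.
	var xs []sql.NullInt64
	if err := db.QueryRow(`SELECT ARRAY[235, NULL, 401]`).Scan(pq.Array(&xs)); err != nil {
		log.Fatal(err)
	}
	fmt.Println(xs) // [{235 true} {0 false} {401 true}]
}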
+func (a *Int64Array) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to Int64Array", src) +} + +func (a *Int64Array) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "Int64Array") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(Int64Array, len(elems)) + for i, v := range elems { + if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil { + return fmt.Errorf("pq: parsing array element index %d: %v", i, err) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a Int64Array) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, N bytes of values, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+2*n) + b[0] = '{' + + b = strconv.AppendInt(b, a[0], 10) + for i := 1; i < n; i++ { + b = append(b, ',') + b = strconv.AppendInt(b, a[i], 10) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// StringArray represents a one-dimensional array of the PostgreSQL character types. +type StringArray []string + +// Scan implements the sql.Scanner interface. +func (a *StringArray) Scan(src interface{}) error { + switch src := src.(type) { + case []byte: + return a.scanBytes(src) + case string: + return a.scanBytes([]byte(src)) + case nil: + *a = nil + return nil + } + + return fmt.Errorf("pq: cannot convert %T to StringArray", src) +} + +func (a *StringArray) scanBytes(src []byte) error { + elems, err := scanLinearArray(src, []byte{','}, "StringArray") + if err != nil { + return err + } + if *a != nil && len(elems) == 0 { + *a = (*a)[:0] + } else { + b := make(StringArray, len(elems)) + for i, v := range elems { + if b[i] = string(v); v == nil { + return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i) + } + } + *a = b + } + return nil +} + +// Value implements the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + + if n := len(a); n > 0 { + // There will be at least two curly brackets, 2*N bytes of quotes, + // and N-1 bytes of delimiters. + b := make([]byte, 1, 1+3*n) + b[0] = '{' + + b = appendArrayQuotedBytes(b, []byte(a[0])) + for i := 1; i < n; i++ { + b = append(b, ',') + b = appendArrayQuotedBytes(b, []byte(a[i])) + } + + return string(append(b, '}')), nil + } + + return "{}", nil +} + +// appendArray appends rv to the buffer, returning the extended buffer and +// the delimiter used between elements. +// +// It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice. +func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) { + var del string + var err error + + b = append(b, '{') + + if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil { + return b, del, err + } + + for i := 1; i < n; i++ { + b = append(b, del...) + if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil { + return b, del, err + } + } + + return append(b, '}'), del, nil +} + +// appendArrayElement appends rv to the buffer, returning the extended buffer +// and the delimiter to use before the next element. 
+// +// When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted +// using driver.DefaultParameterConverter and the resulting []byte or string +// is double-quoted. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) { + if k := rv.Kind(); k == reflect.Array || k == reflect.Slice { + if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) { + if n := rv.Len(); n > 0 { + return appendArray(b, rv, n) + } + + return b, "", nil + } + } + + var del = "," + var err error + var iv interface{} = rv.Interface() + + if ad, ok := iv.(ArrayDelimiter); ok { + del = ad.ArrayDelimiter() + } + + if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil { + return b, del, err + } + + switch v := iv.(type) { + case nil: + return append(b, "NULL"...), del, nil + case []byte: + return appendArrayQuotedBytes(b, v), del, nil + case string: + return appendArrayQuotedBytes(b, []byte(v)), del, nil + } + + b, err = appendValue(b, iv) + return b, del, err +} + +func appendArrayQuotedBytes(b, v []byte) []byte { + b = append(b, '"') + for { + i := bytes.IndexAny(v, `"\`) + if i < 0 { + b = append(b, v...) + break + } + if i > 0 { + b = append(b, v[:i]...) + } + b = append(b, '\\', v[i]) + v = v[i+1:] + } + return append(b, '"') +} + +func appendValue(b []byte, v driver.Value) ([]byte, error) { + return append(b, encode(nil, v, 0)...), nil +} + +// parseArray extracts the dimensions and elements of an array represented in +// text format. Only representations emitted by the backend are supported. +// Notably, whitespace around brackets and delimiters is significant, and NULL +// is case-sensitive. +// +// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO +func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) { + var depth, i int + + if len(src) < 1 || src[0] != '{' { + return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0) + } + +Open: + for i < len(src) { + switch src[i] { + case '{': + depth++ + i++ + case '}': + elems = make([][]byte, 0) + goto Close + default: + break Open + } + } + dims = make([]int, i) + +Element: + for i < len(src) { + switch src[i] { + case '{': + if depth == len(dims) { + break Element + } + depth++ + dims[depth-1] = 0 + i++ + case '"': + var elem = []byte{} + var escape bool + for i++; i < len(src); i++ { + if escape { + elem = append(elem, src[i]) + escape = false + } else { + switch src[i] { + default: + elem = append(elem, src[i]) + case '\\': + escape = true + case '"': + elems = append(elems, elem) + i++ + break Element + } + } + } + default: + for start := i; i < len(src); i++ { + if bytes.HasPrefix(src[i:], del) || src[i] == '}' { + elem := src[start:i] + if len(elem) == 0 { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + if bytes.Equal(elem, []byte("NULL")) { + elem = nil + } + elems = append(elems, elem) + break Element + } + } + } + } + + for i < len(src) { + if bytes.HasPrefix(src[i:], del) && depth > 0 { + dims[depth-1]++ + i += len(del) + goto Element + } else if src[i] == '}' && depth > 0 { + dims[depth-1]++ + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i) + } + } + +Close: + for i < len(src) { + if src[i] == '}' && depth > 0 { + depth-- + i++ + } else { + return nil, nil, fmt.Errorf("pq: unable to parse array; 
unexpected %q at offset %d", src[i], i) + } + } + if depth > 0 { + err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i) + } + if err == nil { + for _, d := range dims { + if (len(elems) % d) != 0 { + err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions") + } + } + } + return +} + +func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) { + dims, elems, err := parseArray(src, del) + if err != nil { + return nil, err + } + if len(dims) > 1 { + return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ) + } + return elems, err +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/array_test.go golang-github-lib-pq-1.3.0/array_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/array_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/array_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,1311 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "math/rand" + "reflect" + "strings" + "testing" +) + +func TestParseArray(t *testing.T) { + for _, tt := range []struct { + input string + delim string + dims []int + elems [][]byte + }{ + {`{}`, `,`, nil, [][]byte{}}, + {`{NULL}`, `,`, []int{1}, [][]byte{nil}}, + {`{a}`, `,`, []int{1}, [][]byte{{'a'}}}, + {`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}}, + {`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}}, + {`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}}, + {`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + }}, + {`{""}`, `,`, []int{1}, [][]byte{{}}}, + {`{","}`, `,`, []int{1}, [][]byte{{','}}}, + {`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}}, + {`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}}, + {`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}}, + {`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {','}, {','}, {','}, {','}, {','}, {','}, + }}, + {`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}}, + {`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}}, + {`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, + }}, + {`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}}, + } { + dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim)) + + if err != nil { + t.Fatalf("Expected no error for %q, got %q", tt.input, err) + } + if !reflect.DeepEqual(dims, tt.dims) { + t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims) + } + if !reflect.DeepEqual(elems, tt.elems) { + t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems) + } + } +} + +func TestParseArrayError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "expected '{' at offset 0"}, + {`x`, "expected '{' at offset 0"}, + {`}`, "expected '{' at offset 0"}, + {`{`, "expected '}' at offset 1"}, + {`{{}`, "expected '}' at offset 3"}, + {`{}}`, "unexpected '}' at offset 2"}, + {`{,}`, "unexpected ',' at offset 1"}, + {`{,x}`, "unexpected ',' at offset 1"}, + {`{x,}`, "unexpected '}' at offset 3"}, + {`{x,{`, "unexpected '{' at offset 3"}, + {`{x},`, "unexpected ',' at offset 3"}, + {`{x}}`, "unexpected '}' at offset 3"}, + {`{{x}`, "expected '}' at offset 4"}, + {`{""x}`, "unexpected 'x' at offset 3"}, + {`{{a},{b,c}}`, 
"multidimensional arrays must have elements with matching dimensions"}, + } { + _, _, err := parseArray([]byte(tt.input), []byte{','}) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + } +} + +func TestArrayScanner(t *testing.T) { + var s sql.Scanner = Array(&[]bool{}) + if _, ok := s.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", s) + } + + s = Array(&[]float64{}) + if _, ok := s.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", s) + } + + s = Array(&[]int64{}) + if _, ok := s.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", s) + } + + s = Array(&[]string{}) + if _, ok := s.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", s) + } + + for _, tt := range []interface{}{ + &[]sql.Scanner{}, + &[][]bool{}, + &[][]float64{}, + &[][]int64{}, + &[][]string{}, + } { + s = Array(tt) + if _, ok := s.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, s) + } + } +} + +func TestArrayValuer(t *testing.T) { + var v driver.Valuer = Array([]bool{}) + if _, ok := v.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", v) + } + + v = Array([]float64{}) + if _, ok := v.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", v) + } + + v = Array([]int64{}) + if _, ok := v.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", v) + } + + v = Array([]string{}) + if _, ok := v.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", v) + } + + for _, tt := range []interface{}{ + nil, + []driver.Value{}, + [][]bool{}, + [][]float64{}, + [][]int64{}, + [][]string{}, + } { + v = Array(tt) + if _, ok := v.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, v) + } + } +} + +func TestBoolArrayScanUnsupported(t *testing.T) { + var arr BoolArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BoolArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestBoolArrayScanEmpty(t *testing.T) { + var arr BoolArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestBoolArrayScanNil(t *testing.T) { + arr := BoolArray{true, true, true} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var BoolArrayStringTests = []struct { + str string + arr BoolArray +}{ + {`{}`, BoolArray{}}, + {`{t}`, BoolArray{true}}, + {`{f,t}`, BoolArray{false, true}}, +} + +func TestBoolArrayScanBytes(t *testing.T) { + for _, tt := range BoolArrayStringTests { + bytes := []byte(tt.str) + arr := BoolArray{true, true, true} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBoolArrayScanBytes(b *testing.B) { + var a BoolArray + var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`) + + for i := 0; i < b.N; i++ { + a = BoolArray{} + a.Scan(x) + } +} + +func TestBoolArrayScanString(t *testing.T) { + for _, tt := range BoolArrayStringTests { + arr := BoolArray{true, true, true} + err := 
arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBoolArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"}, + {`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`}, + {`{a}`, `could not parse boolean array index 0: invalid boolean "a"`}, + {`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`}, + {`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`}, + } { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BoolArray{true, true, true}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBoolArrayValue(t *testing.T) { + result, err := BoolArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BoolArray([]bool{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = BoolArray([]bool{false, true, false}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkBoolArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := BoolArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestByteaArrayScanUnsupported(t *testing.T) { + var arr ByteaArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to ByteaArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestByteaArrayScanEmpty(t *testing.T) { + var arr ByteaArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestByteaArrayScanNil(t *testing.T) { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var ByteaArrayStringTests = []struct { + str string + arr ByteaArray +}{ + {`{}`, ByteaArray{}}, + {`{NULL}`, ByteaArray{nil}}, + {`{"\\xfeff"}`, ByteaArray{{'\xFE', '\xFF'}}}, + {`{"\\xdead","\\xbeef"}`, ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, +} + +func TestByteaArrayScanBytes(t *testing.T) { + for _, tt := range ByteaArrayStringTests { + bytes := []byte(tt.str) + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + 
} +} + +func BenchmarkByteaArrayScanBytes(b *testing.B) { + var a ByteaArray + var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`) + + for i := 0; i < b.N; i++ { + a = ByteaArray{} + a.Scan(x) + } +} + +func TestByteaArrayScanString(t *testing.T) { + for _, tt := range ByteaArrayStringTests { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestByteaArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to ByteaArray"}, + {`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"}, + } { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, ByteaArray{{2}, {6}, {0, 0}}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestByteaArrayValue(t *testing.T) { + result, err := ByteaArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = ByteaArray([][]byte{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = ByteaArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkByteaArrayValue(b *testing.B) { + rand.Seed(1) + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = make([]byte, len(x)) + for j := 0; j < len(x); j++ { + x[i][j] = byte(rand.Int()) + } + } + a := ByteaArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestFloat64ArrayScanUnsupported(t *testing.T) { + var arr Float64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat64ArrayScanEmpty(t *testing.T) { + var arr Float64Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat64ArrayScanNil(t *testing.T) { + arr := Float64Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float64ArrayStringTests = []struct { + str string + arr Float64Array +}{ + {`{}`, Float64Array{}}, + {`{1.2}`, Float64Array{1.2}}, + {`{3.456,7.89}`, Float64Array{3.456, 7.89}}, + {`{3,1,2}`, Float64Array{3, 1, 2}}, +} + +func TestFloat64ArrayScanBytes(t 
*testing.T) { + for _, tt := range Float64ArrayStringTests { + bytes := []byte(tt.str) + arr := Float64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat64ArrayScanBytes(b *testing.B) { + var a Float64Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float64Array{} + a.Scan(x) + } +} + +func TestFloat64ArrayScanString(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat64ArrayValue(t *testing.T) { + result, err := Float64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float64Array([]float64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := Float64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt64ArrayScanUnsupported(t *testing.T) { + var arr Int64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt64ArrayScanEmpty(t *testing.T) { + var arr Int64Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt64ArrayScanNil(t *testing.T) { + arr := Int64Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var 
Int64ArrayStringTests = []struct { + str string + arr Int64Array +}{ + {`{}`, Int64Array{}}, + {`{12}`, Int64Array{12}}, + {`{345,678}`, Int64Array{345, 678}}, +} + +func TestInt64ArrayScanBytes(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + bytes := []byte(tt.str) + arr := Int64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt64ArrayScanBytes(b *testing.B) { + var a Int64Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int64Array{} + a.Scan(x) + } +} + +func TestInt64ArrayScanString(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt64ArrayValue(t *testing.T) { + result, err := Int64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int64Array([]int64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int64Array([]int64{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := Int64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestStringArrayScanUnsupported(t *testing.T) { + var arr StringArray + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to StringArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestStringArrayScanEmpty(t *testing.T) { + var arr StringArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestStringArrayScanNil(t *testing.T) { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(nil) + + if err != nil { + 
t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var StringArrayStringTests = []struct { + str string + arr StringArray +}{ + {`{}`, StringArray{}}, + {`{t}`, StringArray{"t"}}, + {`{f,1}`, StringArray{"f", "1"}}, + {`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}}, +} + +func TestStringArrayScanBytes(t *testing.T) { + for _, tt := range StringArrayStringTests { + bytes := []byte(tt.str) + arr := StringArray{"x", "x", "x"} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkStringArrayScanBytes(b *testing.B) { + var a StringArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = StringArray{} + a.Scan(x) + a = StringArray{} + a.Scan(y) + } +} + +func TestStringArrayScanString(t *testing.T) { + for _, tt := range StringArrayStringTests { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestStringArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"}, + {`{NULL}`, "parsing array element index 0: cannot convert nil to string"}, + {`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"}, + {`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"}, + } { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestStringArrayValue(t *testing.T) { + result, err := StringArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = StringArray([]string{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkStringArrayValue(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := StringArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestGenericArrayScanUnsupported(t *testing.T) { + var s string + var ss []string + var nsa [1]sql.NullString + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, nil, "destination is not a pointer to array or slice"}, + {nil, true, 
"destination bool is not a pointer to array or slice"}, + {nil, &s, "destination *string is not a pointer to array or slice"}, + {nil, ss, "destination []string is not a pointer to array or slice"}, + {nil, &nsa, " to [1]sql.NullString"}, + {true, &ss, "bool to []string"}, + {`{{x}}`, &ss, "multidimensional ARRAY[1][1] is not implemented"}, + {`{{x},{x}}`, &ss, "multidimensional ARRAY[2][1] is not implemented"}, + {`{x}`, &ss, "scanning to string is not implemented"}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayScanScannerArrayBytes(t *testing.T) { + src, expected, nsa := []byte(`{NULL,abc,"\""}`), + [3]sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerArrayString(t *testing.T) { + src, expected, nsa := `{NULL,"\"",xyz}`, + [3]sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerSliceEmpty(t *testing.T) { + var nss []sql.NullString + + if err := (GenericArray{&nss}).Scan(`{}`); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if nss == nil || len(nss) != 0 { + t.Errorf("Expected empty, got %#v", nss) + } +} + +func TestGenericArrayScanScannerSliceNil(t *testing.T) { + nss := []sql.NullString{{String: ``, Valid: true}, {}} + + if err := (GenericArray{&nss}).Scan(nil); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if nss != nil { + t.Errorf("Expected nil, got %+v", nss) + } +} + +func TestGenericArrayScanScannerSliceBytes(t *testing.T) { + src, expected, nss := []byte(`{NULL,abc,"\""}`), + []sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +func BenchmarkGenericArrayScanScannerSliceBytes(b *testing.B) { + var a GenericArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = GenericArray{new([]sql.NullString)} + a.Scan(x) + a = GenericArray{new([]sql.NullString)} + a.Scan(y) + } +} + +func TestGenericArrayScanScannerSliceString(t *testing.T) { + src, expected, nss := `{NULL,"\"",xyz}`, + []sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +type TildeNullInt64 struct{ 
sql.NullInt64 } + +func (TildeNullInt64) ArrayDelimiter() string { return "~" } + +func TestGenericArrayScanDelimiter(t *testing.T) { + src, expected, tnis := `{12~NULL~76}`, + []TildeNullInt64{{sql.NullInt64{Int64: 12, Valid: true}}, {}, {sql.NullInt64{Int64: 76, Valid: true}}}, + []TildeNullInt64{{sql.NullInt64{Int64: 0, Valid: true}}, {}} + + if err := (GenericArray{&tnis}).Scan(src); err != nil { + t.Fatalf("Expected no error for %#v, got %v", src, err) + } + if !reflect.DeepEqual(tnis, expected) { + t.Errorf("Expected %v for %#v, got %v", expected, src, tnis) + } +} + +func TestGenericArrayScanErrors(t *testing.T) { + var sa [1]string + var nis []sql.NullInt64 + var pss *[]string + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, pss, "destination *[]string is nil"}, + {`{`, &sa, "unable to parse"}, + {`{}`, &sa, "cannot convert ARRAY[0] to [1]string"}, + {`{x,x}`, &sa, "cannot convert ARRAY[2] to [1]string"}, + {`{x}`, &nis, `parsing array element index 0: converting`}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayValueUnsupported(t *testing.T) { + _, err := GenericArray{true}.Value() + + if err == nil { + t.Fatal("Expected error for bool") + } + if !strings.Contains(err.Error(), "bool to array") { + t.Errorf("Expected type to be mentioned, got %q", err) + } +} + +type ByteArrayValuer [1]byte +type ByteSliceValuer []byte +type FuncArrayValuer struct { + delimiter func() string + value func() (driver.Value, error) +} + +func (a ByteArrayValuer) Value() (driver.Value, error) { return a[:], nil } +func (b ByteSliceValuer) Value() (driver.Value, error) { return []byte(b), nil } +func (f FuncArrayValuer) ArrayDelimiter() string { return f.delimiter() } +func (f FuncArrayValuer) Value() (driver.Value, error) { return f.value() } + +func TestGenericArrayValue(t *testing.T) { + result, err := GenericArray{nil}.Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + for _, tt := range []interface{}{ + []bool(nil), + [][]int(nil), + []*int(nil), + []sql.NullString(nil), + } { + result, err := GenericArray{tt}.Value() + + if err != nil { + t.Fatalf("Expected no error for %#v, got %v", tt, err) + } + if result != nil { + t.Errorf("Expected nil for %#v, got %q", tt, result) + } + } + + Tilde := func(v driver.Value) FuncArrayValuer { + return FuncArrayValuer{ + func() string { return "~" }, + func() (driver.Value, error) { return v, nil }} + } + + for _, tt := range []struct { + result string + input interface{} + }{ + {`{}`, []bool{}}, + {`{true}`, []bool{true}}, + {`{true,false}`, []bool{true, false}}, + {`{true,false}`, [2]bool{true, false}}, + + {`{}`, [][]int{{}}}, + {`{}`, [][]int{{}, {}}}, + {`{{1}}`, [][]int{{1}}}, + {`{{1},{2}}`, [][]int{{1}, {2}}}, + {`{{1,2},{3,4}}`, [][]int{{1, 2}, {3, 4}}}, + {`{{1,2},{3,4}}`, [2][2]int{{1, 2}, {3, 4}}}, + + {`{"a","\\b","c\"","d,e"}`, []string{`a`, `\b`, `c"`, `d,e`}}, + {`{"a","\\b","c\"","d,e"}`, [][]byte{{'a'}, {'\\', 'b'}, {'c', '"'}, {'d', ',', 'e'}}}, + + {`{NULL}`, []*int{nil}}, + {`{0,NULL}`, []*int{new(int), nil}}, + + {`{NULL}`, []sql.NullString{{}}}, + {`{"\"",NULL}`, []sql.NullString{{String: `"`, Valid: true}, {}}}, + + {`{"a","b"}`, 
[]ByteArrayValuer{{'a'}, {'b'}}}, + {`{{"a","b"},{"c","d"}}`, [][]ByteArrayValuer{{{'a'}, {'b'}}, {{'c'}, {'d'}}}}, + + {`{"e","f"}`, []ByteSliceValuer{{'e'}, {'f'}}}, + {`{{"e","f"},{"g","h"}}`, [][]ByteSliceValuer{{{'e'}, {'f'}}, {{'g'}, {'h'}}}}, + + {`{1~2}`, []FuncArrayValuer{Tilde(int64(1)), Tilde(int64(2))}}, + {`{{1~2}~{3~4}}`, [][]FuncArrayValuer{{Tilde(int64(1)), Tilde(int64(2))}, {Tilde(int64(3)), Tilde(int64(4))}}}, + } { + result, err := GenericArray{tt.input}.Value() + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.input, err) + } + if !reflect.DeepEqual(result, tt.result) { + t.Errorf("Expected %q for %q, got %q", tt.result, tt.input, result) + } + } +} + +func TestGenericArrayValueErrors(t *testing.T) { + v := []interface{}{func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } + + v = []interface{}{nil, func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } +} + +func BenchmarkGenericArrayValueBools(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueFloat64s(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueInt64s(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueByteSlices(b *testing.B) { + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = bytes.Repeat([]byte(`abc"def\ghi`), 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueStrings(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestArrayScanBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tt := range []struct { + s string + d sql.Scanner + e interface{} + }{ + {`ARRAY[true, false]`, new(BoolArray), &BoolArray{true, false}}, + {`ARRAY[E'\\xdead', E'\\xbeef']`, new(ByteaArray), &ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, + {`ARRAY[1.2, 3.4]`, new(Float64Array), &Float64Array{1.2, 3.4}}, + {`ARRAY[1, 2, 3]`, new(Int64Array), &Int64Array{1, 2, 3}}, + {`ARRAY['a', E'\\b', 'c"', 'd,e']`, new(StringArray), &StringArray{`a`, `\b`, `c"`, `d,e`}}, + } { + err := db.QueryRow(`SELECT ` + tt.s).Scan(tt.d) + if err != nil { + t.Errorf("Expected no error when scanning %s into %T, got %v", tt.s, tt.d, err) + } + if !reflect.DeepEqual(tt.d, tt.e) { + t.Errorf("Expected %v when scanning %s into %T, got %v", tt.e, tt.s, tt.d, tt.d) + } + } +} + +func TestArrayValueBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tt := range []struct { + s string + v driver.Valuer + }{ + {`ARRAY[true, false]`, BoolArray{true, false}}, + {`ARRAY[E'\\xdead', E'\\xbeef']`, ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, + {`ARRAY[1.2, 3.4]`, Float64Array{1.2, 3.4}}, + {`ARRAY[1, 2, 3]`, Int64Array{1, 2, 3}}, + {`ARRAY['a', E'\\b', 'c"', 'd,e']`, StringArray{`a`, `\b`, `c"`, `d,e`}}, + } { + var x int + err := 
db.QueryRow(`SELECT 1 WHERE `+tt.s+` <> $1`, tt.v).Scan(&x) + if err != sql.ErrNoRows { + t.Errorf("Expected %v to equal %s, got %v", tt.v, tt.s, err) + } + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/bench_test.go golang-github-lib-pq-1.3.0/bench_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/bench_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/bench_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -1,10 +1,9 @@ -// +build go1.1 - package pq import ( "bufio" "bytes" + "context" "database/sql" "database/sql/driver" "io" @@ -156,7 +155,7 @@ b.Fatal(err) } defer stmt.Close() - rows, err := stmt.Query(nil) + rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) if err != nil { b.Fatal(err) } @@ -266,7 +265,7 @@ } func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { - rows, err := stmt.Query(nil) + rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) if err != nil { b.Fatal(err) } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/buf.go golang-github-lib-pq-1.3.0/buf.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/buf.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/buf.go 2019-12-11 04:41:10.000000000 +0000 @@ -66,7 +66,7 @@ } func (b *writeBuf) string(s string) { - b.buf = append(b.buf, (s + "\000")...) + b.buf = append(append(b.buf, s...), '\000') } func (b *writeBuf) byte(c byte) { diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/buf_test.go golang-github-lib-pq-1.3.0/buf_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/buf_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/buf_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,16 @@ +package pq + +import "testing" + +func Benchmark_writeBuf_string(b *testing.B) { + var buf writeBuf + const s = "foo" + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + buf.string(s) + buf.buf = buf.buf[:0] + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/certs/bogus_root.crt golang-github-lib-pq-1.3.0/certs/bogus_root.crt --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/certs/bogus_root.crt 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/certs/bogus_root.crt 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBjCCAe6gAwIBAgIQSnDYp/Naet9HOZljF5PuwDANBgkqhkiG9w0BAQsFADAr +MRIwEAYDVQQKEwlDb2Nrcm9hY2gxFTATBgNVBAMTDENvY2tyb2FjaCBDQTAeFw0x +NjAyMDcxNjQ0MzdaFw0xNzAyMDYxNjQ0MzdaMCsxEjAQBgNVBAoTCUNvY2tyb2Fj +aDEVMBMGA1UEAxMMQ29ja3JvYWNoIENBMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEAxdln3/UdgP7ayA/G1kT7upjLe4ERwQjYQ25q0e1+vgsB5jhiirxJ +e0+WkhhYu/mwoSAXzvlsbZ2PWFyfdanZeD/Lh6SvIeWXVVaPcWVWL1TEcoN2jr5+ +E85MMHmbbmaT2he8s6br2tM/UZxyTQ2XRprIzApbDssyw1c0Yufcpu3C6267FLEl +IfcWrzDhnluFhthhtGXv3ToD8IuMScMC5qlKBXtKmD1B5x14ngO/ecNJ+OlEi0HU +mavK4KWgI2rDXRZ2EnCpyTZdkc3kkRnzKcg653oOjMDRZdrhfIrha+Jq38ACsUmZ +Su7Sp5jkIHOCO8Zg+l6GKVSq37dKMapD8wIDAQABoyYwJDAOBgNVHQ8BAf8EBAMC +AuQwEgYDVR0TAQH/BAgwBgEB/wIBATANBgkqhkiG9w0BAQsFAAOCAQEAwZ2Tu0Yu +rrSVdMdoPEjT1IZd+5OhM/SLzL0ddtvTithRweLHsw2lDQYlXFqr24i3UGZJQ1sp +cqSrNwswgLUQT3vWyTjmM51HEb2vMYWKmjZ+sBQYAUP1CadrN/+OTfNGnlF1+B4w +IXOzh7EvQmJJnNybLe4a/aRvj1NE2n8Z898B76SVU9WbfKKz8VwLzuIPDqkKcZda +lMy5yzthyztV9YjcWs2zVOUGZvGdAhDrvZuUq6mSmxrBEvR2LBOggmVf3tGRT+Ls +lW7c9Lrva5zLHuqmoPP07A+vuI9a0D1X44jwGDuPWJ5RnTOQ63Uez12mKNjqleHw +DnkwNanuO8dhAA== +-----END CERTIFICATE----- diff -Nru 
golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector_example_test.go golang-github-lib-pq-1.3.0/connector_example_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector_example_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/connector_example_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,29 @@ +// +build go1.10 + +package pq_test + +import ( + "database/sql" + "fmt" + + "github.com/lib/pq" +) + +func ExampleNewConnector() { + name := "" + connector, err := pq.NewConnector(name) + if err != nil { + fmt.Println(err) + return + } + db := sql.OpenDB(connector) + defer db.Close() + + // Use the DB + txn, err := db.Begin() + if err != nil { + fmt.Println(err) + return + } + txn.Rollback() +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector.go golang-github-lib-pq-1.3.0/connector.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/connector.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,110 @@ +package pq + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "os" + "strings" +) + +// Connector represents a fixed configuration for the pq driver with a given +// name. Connector satisfies the database/sql/driver Connector interface and +// can be used to create any number of DB Conn's via the database/sql OpenDB +// function. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +type Connector struct { + opts values + dialer Dialer +} + +// Connect returns a connection to the database using the fixed configuration +// of this Connector. Context is not used. +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + return c.open(ctx) +} + +// Driver returnst the underlying driver of this Connector. +func (c *Connector) Driver() driver.Driver { + return &Driver{} +} + +// NewConnector returns a connector for the pq driver in a fixed configuration +// with the given dsn. The returned connector can be used to create any number +// of equivalent Conn's. The returned connector is intended to be used with +// database/sql.OpenDB. +// +// See https://golang.org/pkg/database/sql/driver/#Connector. +// See https://golang.org/pkg/database/sql/#OpenDB. +func NewConnector(dsn string) (*Connector, error) { + var err error + o := make(values) + + // A number of defaults are applied here, in this order: + // + // * Very low precedence defaults applied in every situation + // * Environment variables + // * Explicitly passed connection information + o["host"] = "localhost" + o["port"] = "5432" + // N.B.: Extra float digits should be set to 3, but that breaks + // Postgres 8.4 and older, where the max is 2. + o["extra_float_digits"] = "2" + for k, v := range parseEnviron(os.Environ()) { + o[k] = v + } + + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + dsn, err = ParseURL(dsn) + if err != nil { + return nil, err + } + } + + if err := parseOpts(dsn, o); err != nil { + return nil, err + } + + // Use the "fallback" application name if necessary + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback + } + } + + // We can't work with any client_encoding other than UTF-8 currently. + // However, we have historically allowed the user to set it to UTF-8 + // explicitly, and there's no reason to break such programs, so allow that. 
+ // Note that the "options" setting could also set client_encoding, but + // parsing its value is not worth it. Instead, we always explicitly send + // client_encoding as a separate run-time parameter, which should override + // anything set in options. + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { + return nil, errors.New("client_encoding must be absent or 'UTF8'") + } + o["client_encoding"] = "UTF8" + // DateStyle needs a similar treatment. + if datestyle, ok := o["datestyle"]; ok { + if datestyle != "ISO, MDY" { + return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) + } + } else { + o["datestyle"] = "ISO, MDY" + } + + // If a user is not provided by any other means, the last + // resort is to use the current operating system provided user + // name. + if _, ok := o["user"]; !ok { + u, err := userCurrent() + if err != nil { + return nil, err + } + o["user"] = u + } + + return &Connector{opts: o, dialer: defaultDialer{}}, nil +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector_test.go golang-github-lib-pq-1.3.0/connector_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/connector_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/connector_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,67 @@ +// +build go1.10 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" +) + +func TestNewConnector_WorksWithOpenDB(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db := sql.OpenDB(c) + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} + +func TestNewConnector_Connect(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db, err := c.Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} + +func TestNewConnector_Driver(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db, err := c.Driver().Open(name) + if err != nil { + t.Fatal(err) + } + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn.go golang-github-lib-pq-1.3.0/conn.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/conn.go 2019-12-11 04:41:10.000000000 +0000 @@ -2,16 +2,15 @@ import ( "bufio" + "context" "crypto/md5" - "crypto/tls" - "crypto/x509" + "crypto/sha256" "database/sql" "database/sql/driver" "encoding/binary" "errors" "fmt" "io" - "io/ioutil" "net" "os" "os/user" @@ -23,6 +22,7 @@ "unicode" "github.com/lib/pq/oid" + "github.com/lib/pq/scram" ) // Common error types @@ -30,18 +30,26 @@ ErrNotSupported = errors.New("pq: Unsupported command") ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction") ErrSSLNotSupported = errors.New("pq: SSL is not enabled 
on the server") - ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.") - ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.") + ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less") + ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly") + + errUnexpectedReady = errors.New("unexpected ReadyForQuery") + errNoRowsAffected = errors.New("no RowsAffected available after the empty statement") + errNoLastInsertID = errors.New("no LastInsertId available after the empty statement") ) -type drv struct{} +// Driver is the Postgres database driver. +type Driver struct{} -func (d *drv) Open(name string) (driver.Conn, error) { +// Open opens a new connection to the database. name is a connection string. +// Most users should only use it through database/sql package from the standard +// library. +func (d *Driver) Open(name string) (driver.Conn, error) { return Open(name) } func init() { - sql.Register("postgres", &drv{}) + sql.Register("postgres", &Driver{}) } type parameterStatus struct { @@ -77,18 +85,32 @@ panic("not reached") } +// Dialer is the dialer interface. It can be used to obtain more control over +// how pq creates network connections. type Dialer interface { Dial(network, address string) (net.Conn, error) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) } -type defaultDialer struct{} +// DialerContext is the context-aware dialer interface. +type DialerContext interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +type defaultDialer struct { + d net.Dialer +} -func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { - return net.Dial(ntw, addr) +func (d defaultDialer) Dial(network, address string) (net.Conn, error) { + return d.d.Dial(network, address) } -func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(ntw, addr, timeout) +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.DialContext(ctx, network, address) +} +func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.d.DialContext(ctx, network, address) } type conn struct { @@ -97,6 +119,15 @@ namei int scratch [512]byte txnStatus transactionStatus + txnFinish func() + + // Save connection arguments to use during CancelRequest. + dialer Dialer + opts values + + // Cancellation key data for use with CancelRequest messages. + processID int + secretKey int parameterStatus parameterStatus @@ -115,12 +146,15 @@ // Whether to always send []byte parameters over as binary. Enables single // round-trip mode for non-prepared Query calls. binaryParameters bool + + // If true this connection is in the middle of a COPY + inCopy bool } // Handle driver-side settings in parsed connection string. 
-func (c *conn) handleDriverSettings(o values) (err error) { +func (cn *conn) handleDriverSettings(o values) (err error) { boolSetting := func(key string, val *bool) error { - if value := o.Get(key); value != "" { + if value, ok := o[key]; ok { if value == "yes" { *val = true } else if value == "no" { @@ -132,179 +166,220 @@ return nil } - err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult) + err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err } - err = boolSetting("binary_parameters", &c.binaryParameters) + return boolSetting("binary_parameters", &cn.binaryParameters) +} + +func (cn *conn) handlePgpass(o values) { + // if a password was supplied, do not process .pgpass + if _, ok := o["password"]; ok { + return + } + filename := os.Getenv("PGPASSFILE") + if filename == "" { + // XXX this code doesn't work on Windows where the default filename is + // XXX %APPDATA%\postgresql\pgpass.conf + // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470 + userHome := os.Getenv("HOME") + if userHome == "" { + user, err := user.Current() + if err != nil { + return + } + userHome = user.HomeDir + } + filename = filepath.Join(userHome, ".pgpass") + } + fileinfo, err := os.Stat(filename) if err != nil { - return err + return + } + mode := fileinfo.Mode() + if mode&(0x77) != 0 { + // XXX should warn about incorrect .pgpass permissions as psql does + return + } + file, err := os.Open(filename) + if err != nil { + return + } + defer file.Close() + scanner := bufio.NewScanner(io.Reader(file)) + hostname := o["host"] + ntw, _ := network(o) + port := o["port"] + db := o["dbname"] + username := o["user"] + // From: https://github.com/tg/pgpass/blob/master/reader.go + getFields := func(s string) []string { + fs := make([]string, 0, 5) + f := make([]rune, 0, len(s)) + + var esc bool + for _, c := range s { + switch { + case esc: + f = append(f, c) + esc = false + case c == '\\': + esc = true + case c == ':': + fs = append(fs, string(f)) + f = f[:0] + default: + f = append(f, c) + } + } + return append(fs, string(f)) + } + for scanner.Scan() { + line := scanner.Text() + if len(line) == 0 || line[0] == '#' { + continue + } + split := getFields(line) + if len(split) != 5 { + continue + } + if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) { + o["password"] = split[4] + return + } } - return nil } -func (c *conn) writeBuf(b byte) *writeBuf { - c.scratch[0] = b +func (cn *conn) writeBuf(b byte) *writeBuf { + cn.scratch[0] = b return &writeBuf{ - buf: c.scratch[:5], + buf: cn.scratch[:5], pos: 1, } } -func Open(name string) (_ driver.Conn, err error) { - return DialOpen(defaultDialer{}, name) +// Open opens a new connection to the database. dsn is a connection string. +// Most users should only use it through database/sql package from the standard +// library. +func Open(dsn string) (_ driver.Conn, err error) { + return DialOpen(defaultDialer{}, dsn) +} + +// DialOpen opens a new connection to the database using a dialer. 
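A brief usage sketch of the Dialer and DialerContext interfaces declared above, exercised through the DialOpen function defined next; the loggingDialer type and the connection parameters are illustrative assumptions, not part of this patch.

package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/lib/pq"
)

// loggingDialer wraps net.Dialer and logs every connection attempt. Because it
// also implements DialContext, pq can take the context-aware path when dialing.
type loggingDialer struct{ d net.Dialer }

func (ld loggingDialer) Dial(network, address string) (net.Conn, error) {
	log.Printf("pq dial %s %s", network, address)
	return ld.d.Dial(network, address)
}

func (ld loggingDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return ld.DialContext(ctx, network, address)
}

func (ld loggingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	log.Printf("pq dial (ctx) %s %s", network, address)
	return ld.d.DialContext(ctx, network, address)
}

func main() {
	cn, err := pq.DialOpen(loggingDialer{}, "host=localhost dbname=pqgotest sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer cn.Close()
}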
+func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + c.dialer = d + return c.open(context.Background()) } -func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { +func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any // connection errors into ErrBadConns, hiding the real error message from // the user. defer errRecoverNoErrBadConn(&err) - o := make(values) - - // A number of defaults are applied here, in this order: - // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o.Set("host", "localhost") - o.Set("port", "5432") - // N.B.: Extra float digits should be set to 3, but that breaks - // Postgres 8.4 and older, where the max is 2. - o.Set("extra_float_digits", "2") - for k, v := range parseEnviron(os.Environ()) { - o.Set(k, v) - } + o := c.opts - if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { - name, err = ParseURL(name) - if err != nil { - return nil, err - } + cn = &conn{ + opts: o, + dialer: c.dialer, } - - if err := parseOpts(name, o); err != nil { + err = cn.handleDriverSettings(o) + if err != nil { return nil, err } + cn.handlePgpass(o) - // Use the "fallback" application name if necessary - if fallback := o.Get("fallback_application_name"); fallback != "" { - if !o.Isset("application_name") { - o.Set("application_name", fallback) - } - } - - // We can't work with any client_encoding other than UTF-8 currently. - // However, we have historically allowed the user to set it to UTF-8 - // explicitly, and there's no reason to break such programs, so allow that. - // Note that the "options" setting could also set client_encoding, but - // parsing its value is not worth it. Instead, we always explicitly send - // client_encoding as a separate run-time parameter, which should override - // anything set in options. - if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") - } - o.Set("client_encoding", "UTF8") - // DateStyle needs a similar treatment. - if datestyle := o.Get("datestyle"); datestyle != "" { - if datestyle != "ISO, MDY" { - panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", - "ISO, MDY", datestyle)) - } - } else { - o.Set("datestyle", "ISO, MDY") - } - - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if o.Get("user") == "" { - u, err := userCurrent() - if err != nil { - return nil, err - } else { - o.Set("user", u) - } - } - - cn := &conn{} - err = cn.handleDriverSettings(o) + cn.c, err = dial(ctx, c.dialer, o) if err != nil { return nil, err } - cn.c, err = dial(d, o) + err = cn.ssl(o) if err != nil { + if cn.c != nil { + cn.c.Close() + } return nil, err } - cn.ssl(o) + + // cn.startup panics on error. Make sure we don't leak cn.c. 
+ panicking := true + defer func() { + if panicking { + cn.c.Close() + } + }() + cn.buf = bufio.NewReader(cn.c) cn.startup(o) // reset the deadline, in case one was set (see dial) - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { err = cn.c.SetDeadline(time.Time{}) } + panicking = false return cn, err } -func dial(d Dialer, o values) (net.Conn, error) { - ntw, addr := network(o) +func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { + network, address := network(o) // SSL is not necessary or supported over UNIX domain sockets - if ntw == "unix" { + if network == "unix" { o["sslmode"] = "disable" } // Zero or not specified means wait indefinitely. - if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" { + if timeout, ok := o["connect_timeout"]; ok && timeout != "0" { seconds, err := strconv.ParseInt(timeout, 10, 0) if err != nil { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) } duration := time.Duration(seconds) * time.Second + // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection // establishment and set a deadline for doing the initial handshake. // The deadline is then reset after startup() is done. deadline := time.Now().Add(duration) - conn, err := d.DialTimeout(ntw, addr, duration) + var conn net.Conn + if dctx, ok := d.(DialerContext); ok { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + conn, err = dctx.DialContext(ctx, network, address) + } else { + conn, err = d.DialTimeout(network, address, duration) + } if err != nil { return nil, err } err = conn.SetDeadline(deadline) return conn, err } - return d.Dial(ntw, addr) + if dctx, ok := d.(DialerContext); ok { + return dctx.DialContext(ctx, network, address) + } + return d.Dial(network, address) } func network(o values) (string, string) { - host := o.Get("host") + host := o["host"] if strings.HasPrefix(host, "/") { - sockPath := path.Join(host, ".s.PGSQL."+o.Get("port")) + sockPath := path.Join(host, ".s.PGSQL."+o["port"]) return "unix", sockPath } - return "tcp", host + ":" + o.Get("port") + return "tcp", net.JoinHostPort(host, o["port"]) } type values map[string]string -func (vs values) Set(k, v string) { - vs[k] = v -} - -func (vs values) Get(k string) (v string) { - return vs[k] -} - -func (vs values) Isset(k string) bool { - _, ok := vs[k] - return ok -} - // scanner implements a tokenizer for libpq-style option strings. type scanner struct { s []rune @@ -375,7 +450,7 @@ // Skip any whitespace after the = if r, ok = s.SkipSpaces(); !ok { // If we reach the end here, the last value is just an empty string as per libpq. 
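A hypothetical in-package test sketch of the option-string scanner behaviour described in the comment above (not part of this patch; it assumes placement in a package pq test file with the usual testing import): quoted values may contain spaces, and a trailing key with nothing after the equals sign parses as an empty string.

func TestParseOptsSketch(t *testing.T) {
	o := make(values)
	if err := parseOpts("user=pqgotest application_name='my app' password=", o); err != nil {
		t.Fatal(err)
	}
	if o["application_name"] != "my app" {
		t.Fatalf("unexpected application_name %q", o["application_name"])
	}
	// The trailing "password=" yields an empty string, matching libpq.
	if pw, ok := o["password"]; !ok || pw != "" {
		t.Fatalf("expected empty password to be set, got %q (present=%v)", pw, ok)
	}
}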
- o.Set(string(keyRunes), "") + o[string(keyRunes)] = "" break } @@ -410,7 +485,7 @@ } } - o.Set(string(keyRunes), string(valRunes)) + o[string(keyRunes)] = string(valRunes) } return nil @@ -429,13 +504,17 @@ } func (cn *conn) Begin() (_ driver.Tx, err error) { + return cn.begin("") +} + +func (cn *conn) begin(mode string) (_ driver.Tx, err error) { if cn.bad { return nil, driver.ErrBadConn } defer cn.errRecover(&err) cn.checkIsInTransaction(false) - _, commandTag, err := cn.simpleExec("BEGIN") + _, commandTag, err := cn.simpleExec("BEGIN" + mode) if err != nil { return nil, err } @@ -450,7 +529,14 @@ return cn, nil } +func (cn *conn) closeTxn() { + if finish := cn.txnFinish; finish != nil { + finish() + } +} + func (cn *conn) Commit() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } @@ -464,7 +550,7 @@ // would get the same behaviour if you issued a COMMIT in a failed // transaction, so it's also the least surprising thing to do here. if cn.txnStatus == txnStatusInFailedTransaction { - if err := cn.Rollback(); err != nil { + if err := cn.rollback(); err != nil { return err } return ErrInFailedTransaction @@ -486,11 +572,15 @@ } func (cn *conn) Rollback() (err error) { + defer cn.closeTxn() if cn.bad { return driver.ErrBadConn } defer cn.errRecover(&err) + return cn.rollback() +} +func (cn *conn) rollback() (err error) { cn.checkIsInTransaction(true) _, commandTag, err := cn.simpleExec("ROLLBACK") if err != nil { @@ -523,11 +613,16 @@ res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } // done return case 'E': err = parseError(r) - case 'T', 'D', 'I': + case 'I': + res = emptyRows + case 'T', 'D': // ignore any results default: cn.bad = true @@ -539,8 +634,6 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { defer cn.errRecover(&err) - st := &stmt{cn: cn, name: ""} - b := cn.writeBuf('Q') b.string(q) cn.send(b) @@ -557,13 +650,18 @@ cn.bad = true errorf("unexpected message %q in simple query execution", t) } - res = &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - done: true, + if res == nil { + res = &rows{ + cn: cn, + } } + // Set the result and tag to the last command complete if there wasn't a + // query already run. Although queries usually return from here and cede + // control to Next, a query with zero results does not. + if t == 'C' && res.colNames == nil { + res.result, res.tag = cn.parseComplete(r.string()) + } + res.done = true case 'Z': cn.processReadyForQuery(r) // done @@ -583,7 +681,7 @@ // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) + res.rowsHeader = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -594,9 +692,23 @@ } } +type noRows struct{} + +var emptyRows noRows + +var _ driver.Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errNoLastInsertID +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errNoRowsAffected +} + // Decides which column formats to use for a prepared statement. The input is // an array of type oids, one element per result column. 
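A hypothetical in-package sketch of the behaviour documented above (not part of this patch; it assumes the existing formatText constant and the oid.T_text OID from the oid package): int8 and uuid are on the binary-receive list, while text falls through to the text format, so a mixed result description gets per-column format codes.

func TestDecideColumnFormatsSketch(t *testing.T) {
	colTyps := []fieldDesc{{OID: oid.T_int8}, {OID: oid.T_text}, {OID: oid.T_uuid}}
	colFmts, _ := decideColumnFormats(colTyps, false)
	want := []format{formatBinary, formatText, formatBinary}
	for i, w := range want {
		if colFmts[i] != w {
			t.Fatalf("column %d: got format %v, want %v", i, colFmts[i], w)
		}
	}
}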
-func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) { +func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) { if len(colTyps) == 0 { return nil, colFmtDataAllText } @@ -608,8 +720,8 @@ allBinary := true allText := true - for i, o := range colTyps { - switch o { + for i, t := range colTyps { + switch t.OID { // This is the list of types to use binary mode for when receiving them // through a prepared statement. If a type appears in this list, it // must also be implemented in binaryDecode in encode.go. @@ -620,6 +732,8 @@ case oid.T_int4: fallthrough case oid.T_int2: + fallthrough + case oid.T_uuid: colFmts[i] = formatBinary allText = false @@ -671,32 +785,45 @@ defer cn.errRecover(&err) if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") { - return cn.prepareCopyIn(q) + s, err := cn.prepareCopyIn(q) + if err == nil { + cn.inCopy = true + } + return s, err } return cn.prepareTo(q, cn.gname()), nil } func (cn *conn) Close() (err error) { - if cn.bad { - return driver.ErrBadConn - } + // Skip cn.bad return here because we always want to close a connection. defer cn.errRecover(&err) + // Ensure that cn.c.Close is always run. Since error handling is done with + // panics and cn.errRecover, the Close must be in a defer. + defer func() { + cerr := cn.c.Close() + if err == nil { + err = cerr + } + }() + // Don't go through send(); ListenerConn relies on us not scribbling on the // scratch buffer of this connection. - err = cn.sendSimpleMessage('X') - if err != nil { - return err - } - - return cn.c.Close() + return cn.sendSimpleMessage('X') } // Implement the "Queryer" interface -func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) { +func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return cn.query(query, args) +} + +func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { if cn.bad { return nil, driver.ErrBadConn } + if cn.inCopy { + return nil, errCopyInProgress + } defer cn.errRecover(&err) // Check to see if we can use the "simpleQuery" interface, which is @@ -711,19 +838,16 @@ cn.readParseResponse() cn.readBindResponse() rows := &rows{cn: cn} - rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + rows.rowsHeader = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil - } else { - st := cn.prepareTo(query, "") - st.exec(args) - return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, - }, nil } + st := cn.prepareTo(query, "") + st.exec(args) + return &rows{ + cn: cn, + rowsHeader: st.rowsHeader, + }, nil } // Implement the optional "Execer" interface for one-shot queries @@ -750,17 +874,16 @@ cn.postExecuteWorkaround() res, _, err = cn.readExecuteResponse("Execute") return res, err - } else { - // Use the unnamed statement to defer planning until bind - // time, or else value-based selectivity estimates cannot be - // used. - st := cn.prepareTo(query, "") - r, err := st.Exec(args) - if err != nil { - panic(err) - } - return r, err } + // Use the unnamed statement to defer planning until bind + // time, or else value-based selectivity estimates cannot be + // used. 
+ st := cn.prepareTo(query, "") + r, err := st.Exec(args) + if err != nil { + panic(err) + } + return r, err } func (cn *conn) send(m *writeBuf) { @@ -770,16 +893,9 @@ } } -func (cn *conn) sendStartupPacket(m *writeBuf) { - // sanity check - if m.buf[0] != 0 { - panic("oops") - } - +func (cn *conn) sendStartupPacket(m *writeBuf) error { _, err := cn.c.Write((m.wrap())[1:]) - if err != nil { - panic(err) - } + return err } // Send a message of type typ to the server on the other end of cn. The @@ -851,7 +967,6 @@ if err != nil { panic(err) } - switch t { case 'E': panic(parseError(r)) @@ -892,150 +1007,35 @@ return t, r } -func (cn *conn) ssl(o values) { - verifyCaOnly := false - tlsConf := tls.Config{} - switch mode := o.Get("sslmode"); mode { - case "require", "": - tlsConf.InsecureSkipVerify = true - case "verify-ca": - // We must skip TLS's own verification since it requires full - // verification since Go 1.3. - tlsConf.InsecureSkipVerify = true - verifyCaOnly = true - case "verify-full": - tlsConf.ServerName = o.Get("host") - case "disable": - return - default: - errorf(`unsupported sslmode %q; only "require" (default), "verify-full", and "disable" supported`, mode) +func (cn *conn) ssl(o values) error { + upgrade, err := ssl(o) + if err != nil { + return err } - cn.setupSSLClientCertificates(&tlsConf, o) - cn.setupSSLCA(&tlsConf, o) + if upgrade == nil { + // Nothing to do + return nil + } w := cn.writeBuf(0) w.int32(80877103) - cn.sendStartupPacket(w) + if err = cn.sendStartupPacket(w); err != nil { + return err + } b := cn.scratch[:1] - _, err := io.ReadFull(cn.c, b) + _, err = io.ReadFull(cn.c, b) if err != nil { - panic(err) + return err } if b[0] != 'S' { - panic(ErrSSLNotSupported) - } - - client := tls.Client(cn.c, &tlsConf) - if verifyCaOnly { - cn.verifyCA(client, &tlsConf) - } - cn.c = client -} - -// verifyCA carries out a TLS handshake to the server and verifies the -// presented certificate against the effective CA, i.e. the one specified in -// sslrootcert or the system CA if sslrootcert was not specified. -func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) { - err := client.Handshake() - if err != nil { - panic(err) - } - certs := client.ConnectionState().PeerCertificates - opts := x509.VerifyOptions{ - DNSName: client.ConnectionState().ServerName, - Intermediates: x509.NewCertPool(), - Roots: tlsConf.RootCAs, - } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) - } - _, err = certs[0].Verify(opts) - if err != nil { - panic(err) - } -} - -// This function sets up SSL client certificates based on either the "sslkey" -// and "sslcert" settings (possibly set via the environment variables PGSSLKEY -// and PGSSLCERT, respectively), or if they aren't set, from the .postgresql -// directory in the user's home directory. If the file paths are set -// explicitly, the files must exist. The key file must also not be -// world-readable, or this function will panic with -// ErrSSLKeyHasWorldPermissions. -func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) { - var missingOk bool - - sslkey := o.Get("sslkey") - sslcert := o.Get("sslcert") - if sslkey != "" && sslcert != "" { - // If the user has set an sslkey and sslcert, they *must* exist. - missingOk = false - } else { - // Automatically load certificates from ~/.postgresql. - user, err := user.Current() - if err != nil { - // user.Current() might fail when cross-compiling. 
We have to - // ignore the error and continue without client certificates, since - // we wouldn't know where to load them from. - return - } - - sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") - sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") - missingOk = true - } - - // Check that both files exist, and report the error or stop, depending on - // which behaviour we want. Note that we don't do any more extensive - // checks than this (such as checking that the paths aren't directories); - // LoadX509KeyPair() will take care of the rest. - keyfinfo, err := os.Stat(sslkey) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) - } - _, err = os.Stat(sslcert) - if err != nil && missingOk { - return - } else if err != nil { - panic(err) + return ErrSSLNotSupported } - // If we got this far, the key file must also have the correct permissions - kmode := keyfinfo.Mode() - if kmode != kmode&0600 { - panic(ErrSSLKeyHasWorldPermissions) - } - - cert, err := tls.LoadX509KeyPair(sslcert, sslkey) - if err != nil { - panic(err) - } - tlsConf.Certificates = []tls.Certificate{cert} -} - -// Sets up RootCAs in the TLS configuration if sslrootcert is set. -func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) { - if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" { - tlsConf.RootCAs = x509.NewCertPool() - - cert, err := ioutil.ReadFile(sslrootcert) - if err != nil { - panic(err) - } - - ok := tlsConf.RootCAs.AppendCertsFromPEM(cert) - if !ok { - errorf("couldn't parse pem in sslrootcert") - } - } + cn.c, err = upgrade(cn.c) + return err } // isDriverSetting returns true iff a setting is purely for configuring the @@ -1084,12 +1084,15 @@ w.string(v) } w.string("") - cn.sendStartupPacket(w) + if err := cn.sendStartupPacket(w); err != nil { + panic(err) + } for { t, r := cn.recv() switch t { case 'K': + cn.processBackendKeyData(r) case 'S': cn.processParameterStatus(r) case 'R': @@ -1109,7 +1112,7 @@ // OK case 3: w := cn.writeBuf('p') - w.string(o.Get("password")) + w.string(o["password"]) cn.send(w) t, r := cn.recv() @@ -1123,7 +1126,7 @@ case 5: s := string(r.next(4)) w := cn.writeBuf('p') - w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s)) + w.string("md5" + md5s(md5s(o["password"]+o["user"])+s)) cn.send(w) t, r := cn.recv() @@ -1134,6 +1137,55 @@ if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case 10: + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + sc.Step(nil) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + scOut := sc.Out() + + w := cn.writeBuf('p') + w.string("SCRAM-SHA-256") + w.int32(len(scOut)) + w.bytes(scOut) + cn.send(w) + + t, r := cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 11 { + errorf("unexpected authentication response: %q", t) + } + + nextStep := r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + + scOut = sc.Out() + w = cn.writeBuf('p') + w.bytes(scOut) + cn.send(w) + + t, r = cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 12 { + errorf("unexpected authentication response: %q", t) + } + + nextStep = r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + default: errorf("unknown authentication response: %d", code) } @@ -1145,18 +1197,16 @@ const formatBinary format = 1 // One 
result-column format code with the value 1 (i.e. all binary). -var colFmtDataAllBinary []byte = []byte{0, 1, 0, 1} +var colFmtDataAllBinary = []byte{0, 1, 0, 1} // No result-column format codes (i.e. all text). -var colFmtDataAllText []byte = []byte{0, 0} +var colFmtDataAllText = []byte{0, 0} type stmt struct { - cn *conn - name string - colNames []string - colFmts []format + cn *conn + name string + rowsHeader colFmtData []byte - colTyps []oid.Oid paramTyps []oid.Oid closed bool } @@ -1202,10 +1252,8 @@ st.exec(v) return &rows{ - cn: st.cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: st.cn, + rowsHeader: st.rowsHeader, }, nil } @@ -1315,23 +1363,40 @@ return driver.RowsAffected(n), commandTag } -type rows struct { - cn *conn +type rowsHeader struct { colNames []string - colTyps []oid.Oid + colTyps []fieldDesc colFmts []format - done bool - rb readBuf +} + +type rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader } func (rs *rows) Close() error { + if finish := rs.finish; finish != nil { + defer finish() + } // no need to look at cn.bad as Next() will for { err := rs.Next(nil) switch err { case nil: case io.EOF: - return nil + // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row + // description, used with HasNextResultSet). We need to fetch messages until + // we hit a 'Z', which is done by waiting for done to be set. + if rs.done { + return nil + } default: return err } @@ -1342,6 +1407,17 @@ return rs.colNames } +func (rs *rows) Result() driver.Result { + if rs.result == nil { + return emptyRows + } + return rs.result +} + +func (rs *rows) Tag() string { + return rs.tag +} + func (rs *rows) Next(dest []driver.Value) (err error) { if rs.done { return io.EOF @@ -1359,6 +1435,9 @@ case 'E': err = parseError(&rs.rb) case 'C', 'I': + if t == 'C' { + rs.result, rs.tag = conn.parseComplete(rs.rb.string()) + } continue case 'Z': conn.processReadyForQuery(&rs.rb) @@ -1382,21 +1461,40 @@ dest[i] = nil continue } - dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i]) + dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i]) } return + case 'T': + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next + return io.EOF default: errorf("unexpected message after execute: %q", t) } } } +func (rs *rows) HasNextResultSet() bool { + hasNext := rs.next != nil && !rs.done + return hasNext +} + +func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil + return nil +} + // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be // used as part of an SQL statement. For example: // // tblname := "my_table" // data := "my_data" -// err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data) +// quoted := pq.QuoteIdentifier(tblname) +// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data) // // Any double quotes in name will be escaped. The quoted identifier will be // case sensitive when used in a query. If the input string contains a zero @@ -1409,6 +1507,39 @@ return `"` + strings.Replace(name, `"`, `""`, -1) + `"` } +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. 
For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} + func md5s(s string) string { h := md5.New() h.Write([]byte(s)) @@ -1477,7 +1608,7 @@ cn.send(b) } -func (c *conn) processParameterStatus(r *readBuf) { +func (cn *conn) processParameterStatus(r *readBuf) { var err error param := r.string() @@ -1488,13 +1619,13 @@ var minor int _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor) if err == nil { - c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor + cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor } case "TimeZone": - c.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) + cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string()) if err != nil { - c.parameterStatus.currentLocation = nil + cn.parameterStatus.currentLocation = nil } default: @@ -1502,8 +1633,8 @@ } } -func (c *conn) processReadyForQuery(r *readBuf) { - c.txnStatus = transactionStatus(r.byte()) +func (cn *conn) processReadyForQuery(r *readBuf) { + cn.txnStatus = transactionStatus(r.byte()) } func (cn *conn) readReadyForQuery() { @@ -1518,6 +1649,11 @@ } } +func (cn *conn) processBackendKeyData(r *readBuf) { + cn.processID = r.int32() + cn.secretKey = r.int32() +} + func (cn *conn) readParseResponse() { t, r := cn.recv1() switch t { @@ -1533,7 +1669,7 @@ } } -func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) { +func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) { for { t, r := cn.recv1() switch t { @@ -1559,13 +1695,13 @@ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) { +func (cn *conn) readPortalDescribeResponse() rowsHeader { t, r := cn.recv1() switch t { case 'T': return parsePortalRowDescribe(r) case 'n': - return nil, nil, nil + return rowsHeader{} case 'E': err := parseError(r) cn.readReadyForQuery() @@ -1633,6 +1769,9 @@ res, commandTag = cn.parseComplete(r.string()) case 'Z': cn.processReadyForQuery(r) + if res == nil && err == nil { + err = errUnexpectedReady + } return res, commandTag, err case 'E': 
err = parseError(r) @@ -1641,6 +1780,9 @@ cn.bad = true errorf("unexpected %q after error %s", t, err) } + if t == 'I' { + res = emptyRows + } // ignore any results default: cn.bad = true @@ -1649,34 +1791,40 @@ } } -func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) { +func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) { n := r.int16() colNames = make([]string, n) - colTyps = make([]oid.Oid, n) + colTyps = make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() // format code not known when describing a statement; always 0 r.next(2) } return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) { +func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() - colNames = make([]string, n) - colFmts = make([]format, n) - colTyps = make([]oid.Oid, n) + colNames := make([]string, n) + colFmts := make([]format, n) + colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) - colTyps[i] = r.oid() - r.next(6) + colTyps[i].OID = r.oid() + colTyps[i].Len = r.int16() + colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } - return + return rowsHeader{ + colNames: colNames, + colFmts: colFmts, + colTyps: colTyps, + } } // parseEnviron tries to mimic some of libpq's environment handling @@ -1719,7 +1867,7 @@ accrue("user") case "PGPASSWORD": accrue("password") - case "PGPASSFILE", "PGSERVICE", "PGSERVICEFILE", "PGREALM": + case "PGSERVICE", "PGSERVICEFILE", "PGREALM": unsupported() case "PGOPTIONS": accrue("options") diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn_go18.go golang-github-lib-pq-1.3.0/conn_go18.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn_go18.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/conn_go18.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,149 @@ +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "time" +) + +// Implement the "QueryerContext" interface +func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + finish := cn.watchCancel(ctx) + r, err := cn.query(query, list) + if err != nil { + if finish != nil { + finish() + } + return nil, err + } + r.finish = finish + return r, nil +} + +// Implement the "ExecerContext" interface +func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + list := make([]driver.Value, len(args)) + for i, nv := range args { + list[i] = nv.Value + } + + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + + return cn.Exec(query, list) +} + +// Implement the "ConnBeginTx" interface +func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + var mode string + + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + // Don't touch mode: use the server's default + case sql.LevelReadUncommitted: + mode = " ISOLATION LEVEL READ UNCOMMITTED" + case sql.LevelReadCommitted: + mode = " ISOLATION LEVEL READ COMMITTED" + case sql.LevelRepeatableRead: + mode = " ISOLATION LEVEL REPEATABLE READ" + case sql.LevelSerializable: + mode = " ISOLATION LEVEL SERIALIZABLE" + default: + return 
nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation) + } + + if opts.ReadOnly { + mode += " READ ONLY" + } else { + mode += " READ WRITE" + } + + tx, err := cn.begin(mode) + if err != nil { + return nil, err + } + cn.txnFinish = cn.watchCancel(ctx) + return tx, nil +} + +func (cn *conn) Ping(ctx context.Context) error { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + rows, err := cn.simpleQuery(";") + if err != nil { + return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + } + rows.Close() + return nil +} + +func (cn *conn) watchCancel(ctx context.Context) func() { + if done := ctx.Done(); done != nil { + finished := make(chan struct{}) + go func() { + select { + case <-done: + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. + // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _ = cn.cancel(ctxCancel) + finished <- struct{}{} + case <-finished: + } + }() + return func() { + select { + case <-finished: + case finished <- struct{}{}: + } + } + } + return nil +} + +func (cn *conn) cancel(ctx context.Context) error { + c, err := dial(ctx, cn.dialer, cn.opts) + if err != nil { + return err + } + defer c.Close() + + { + can := conn{ + c: c, + } + err = can.ssl(cn.opts) + if err != nil { + return err + } + + w := can.writeBuf(0) + w.int32(80877102) // cancel request code + w.int32(cn.processID) + w.int32(cn.secretKey) + + if err := can.sendStartupPacket(w); err != nil { + return err + } + } + + // Read until EOF to ensure that the server received the cancel. + { + _, err := io.Copy(ioutil.Discard, c) + return err + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn_test.go golang-github-lib-pq-1.3.0/conn_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/conn_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/conn_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -1,10 +1,12 @@ package pq import ( + "context" "database/sql" "database/sql/driver" "fmt" "io" + "net" "os" "reflect" "strings" @@ -27,7 +29,7 @@ } } -func openTestConnConninfo(conninfo string) (*sql.DB, error) { +func testConninfo(conninfo string) string { defaultTo := func(envvar string, value string) { if os.Getenv(envvar) == "" { os.Setenv(envvar, value) @@ -40,10 +42,13 @@ if forceBinaryParameters() && !strings.HasPrefix(conninfo, "postgres://") && !strings.HasPrefix(conninfo, "postgresql://") { - conninfo = conninfo + " binary_parameters=yes" + conninfo += " binary_parameters=yes" } + return conninfo +} - return sql.Open("postgres", conninfo) +func openTestConnConninfo(conninfo string) (*sql.DB, error) { + return sql.Open("postgres", testConninfo(conninfo)) } func openTestConn(t Fatalistic) *sql.DB { @@ -135,6 +140,95 @@ testURL("postgresql://") } +const pgpassFile = "/tmp/pqgotest_pgpass" + +func TestPgpass(t *testing.T) { + if os.Getenv("TRAVIS") != "true" { + t.Skip("not running under Travis, skipping pgpass tests") + } + + testAssert := func(conninfo string, expected string, reason string) { + conn, err := openTestConnConninfo(conninfo) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + txn, err := conn.Begin() + if err != nil { + if expected != "fail" { + t.Fatalf(reason, err) + } + return + } + rows, err := txn.Query("SELECT USER") + if err != nil { + txn.Rollback() + if expected != "fail" { + 
t.Fatalf(reason, err) + } + } else { + rows.Close() + if expected != "ok" { + t.Fatalf(reason, err) + } + } + txn.Rollback() + } + testAssert("", "ok", "missing .pgpass, unexpected error %#v") + os.Setenv("PGPASSFILE", pgpassFile) + testAssert("host=/tmp", "fail", ", unexpected error %#v") + os.Remove(pgpassFile) + pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + t.Fatalf("Unexpected error writing pgpass file %#v", err) + } + _, err = pgpass.WriteString(`# comment +server:5432:some_db:some_user:pass_A +*:5432:some_db:some_user:pass_B +localhost:*:*:*:pass_C +*:*:*:*:pass_fallback +`) + if err != nil { + t.Fatalf("Unexpected error writing pgpass file %#v", err) + } + pgpass.Close() + + assertPassword := func(extra values, expected string) { + o := values{ + "host": "localhost", + "sslmode": "disable", + "connect_timeout": "20", + "user": "majid", + "port": "5432", + "extra_float_digits": "2", + "dbname": "pqgotest", + "client_encoding": "UTF8", + "datestyle": "ISO, MDY", + } + for k, v := range extra { + o[k] = v + } + (&conn{}).handlePgpass(o) + if pw := o["password"]; pw != expected { + t.Fatalf("For %v expected %s got %s", extra, expected, pw) + } + } + // wrong permissions for the pgpass file means it should be ignored + assertPassword(values{"host": "example.com", "user": "foo"}, "") + // fix the permissions and check if it has taken effect + os.Chmod(pgpassFile, 0600) + assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") + assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") + assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") + // localhost also matches the default "" and UNIX sockets + assertPassword(values{"host": "", "user": "some_user"}, "pass_C") + assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") + // cleanup + os.Remove(pgpassFile) + os.Setenv("PGPASSFILE", "") +} + func TestExec(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -296,10 +390,16 @@ db := openTestConn(t) defer db.Close() - _, err := db.Exec("") + res, err := db.Exec("") if err != nil { t.Fatal(err) } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) + } rows, err := db.Query("") if err != nil { t.Fatal(err) @@ -322,10 +422,16 @@ if err != nil { t.Fatal(err) } - _, err = stmt.Exec() + res, err = stmt.Exec() if err != nil { t.Fatal(err) } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) + } rows, err = stmt.Query() if err != nil { t.Fatal(err) @@ -345,6 +451,59 @@ } } +// Test that rows.Columns() is correct even if there are no result rows. 
+func TestEmptyResultSetColumns(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar WHERE FALSE") + if err != nil { + t.Fatal(err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 2 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + if cols[0] != "a" || cols[1] != "bar" { + t.Fatalf("unexpected Columns result %v", cols) + } + + stmt, err := db.Prepare("SELECT $1::int AS a, text 'bar' AS bar WHERE FALSE") + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Query(1) + if err != nil { + t.Fatal(err) + } + cols, err = rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 2 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + if cols[0] != "a" || cols[1] != "bar" { + t.Fatalf("unexpected Columns result %v", cols) + } + +} + func TestEncodeDecode(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -482,6 +641,57 @@ } } +type testConn struct { + closed bool + net.Conn +} + +func (c *testConn) Close() error { + c.closed = true + return c.Conn.Close() +} + +type testDialer struct { + conns []*testConn +} + +func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) { + c, err := net.Dial(ntw, addr) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { + c, err := net.DialTimeout(ntw, addr, timeout) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func TestErrorDuringStartupClosesConn(t *testing.T) { + // Don't use the normal connection setup, this is intended to + // blow up in the startup packet from a non-existent user. + var d testDialer + c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist")) + if err == nil { + c.Close() + t.Fatal("expected dial error") + } + if len(d.conns) != 1 { + t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1) + } + if !d.conns[0].closed { + t.Error("connection leaked") + } +} + func TestBadConn(t *testing.T) { var err error @@ -511,6 +721,59 @@ } } +// TestCloseBadConn tests that the underlying connection can be closed with +// Close after an error. +func TestCloseBadConn(t *testing.T) { + host := os.Getenv("PGHOST") + if host == "" { + host = "localhost" + } + port := os.Getenv("PGPORT") + if port == "" { + port = "5432" + } + nc, err := net.Dial("tcp", host+":"+port) + if err != nil { + t.Fatal(err) + } + cn := conn{c: nc} + func() { + defer cn.errRecover(&err) + panic(io.EOF) + }() + // Verify we can write before closing. + if _, err := nc.Write(nil); err != nil { + t.Fatal(err) + } + // First close should close the connection. + if err := cn.Close(); err != nil { + t.Fatal(err) + } + + // During the Go 1.9 cycle, https://github.com/golang/go/commit/3792db5 + // changed this error from + // + // net.errClosing = errors.New("use of closed network connection") + // + // to + // + // internal/poll.ErrClosing = errors.New("use of closed file or network connection") + const errClosing = "use of closed" + + // Verify write after closing fails. 
+ if _, err := nc.Write(nil); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) + } + // Verify second close fails. + if err := cn.Close(); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) + } +} + func TestErrorOnExec(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -735,12 +998,14 @@ db := openTestConn(t) defer db.Close() - rows, err := db.Query("PARSE_ERROR $1", 1) - if err == nil { - t.Fatal("expected error") + _, err := db.Query("PARSE_ERROR $1", 1) + pqErr, _ := err.(*Error) + // Expecting a syntax error. + if err == nil || pqErr == nil || pqErr.Code != "42601" { + t.Fatalf("expected syntax error, got %s", err) } - rows, err = db.Query("SELECT 1") + rows, err := db.Query("SELECT 1") if err != nil { t.Fatal(err) } @@ -853,16 +1118,16 @@ db := openTestConn(t) defer db.Close() - var search_path string + var searchPath string err := db.QueryRow(` SET LOCAL search_path TO pg_catalog; SET LOCAL search_path TO pg_catalog; - SHOW search_path`).Scan(&search_path) + SHOW search_path`).Scan(&searchPath) if err != nil { t.Fatal(err) } - if search_path != "pg_catalog" { - t.Fatalf("unexpected search_path %s", search_path) + if searchPath != "pg_catalog" { + t.Fatalf("unexpected search_path %s", searchPath) } } @@ -870,10 +1135,11 @@ db := openTestConn(t) defer db.Close() - row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999'") + row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999', float4 '1.2'") var float4val float32 var float8val float64 - err := row.Scan(&float4val, &float8val) + var float4val2 float64 + err := row.Scan(&float4val, &float8val, &float4val2) if err != nil { t.Fatal(err) } @@ -883,6 +1149,9 @@ if float8val != float64(35.03554004971999) { t.Errorf("Expected float8 fidelity to be maintained; got no match") } + if float4val2 != float64(1.2) { + t.Errorf("Expected float4 fidelity into a float64 to be maintained; got no match") + } } func TestXactMultiStmt(t *testing.T) { @@ -1005,16 +1274,11 @@ tpc("SELECT foo", "", 0, true) // invalid row count } -func TestExecerInterface(t *testing.T) { - // Gin up a straw man private struct just for the type check - cn := &conn{c: nil} - var cni interface{} = cn - - _, ok := cni.(driver.Execer) - if !ok { - t.Fatal("Driver doesn't implement Execer") - } -} +// Test interface conformance. 
+var ( + _ driver.ExecerContext = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) +) func TestNullAfterNonNull(t *testing.T) { db := openTestConn(t) @@ -1192,36 +1456,29 @@ } func TestRuntimeParameters(t *testing.T) { - type RuntimeTestResult int - const ( - ResultUnknown RuntimeTestResult = iota - ResultSuccess - ResultError // other error - ) - tests := []struct { - conninfo string - param string - expected string - expectedOutcome RuntimeTestResult + conninfo string + param string + expected string + success bool }{ // invalid parameter - {"DOESNOTEXIST=foo", "", "", ResultError}, + {"DOESNOTEXIST=foo", "", "", false}, // we can only work with a specific value for these two - {"client_encoding=SQL_ASCII", "", "", ResultError}, - {"datestyle='ISO, YDM'", "", "", ResultError}, + {"client_encoding=SQL_ASCII", "", "", false}, + {"datestyle='ISO, YDM'", "", "", false}, // "options" should work exactly as it does in libpq - {"options='-c search_path=pqgotest'", "search_path", "pqgotest", ResultSuccess}, + {"options='-c search_path=pqgotest'", "search_path", "pqgotest", true}, // pq should override client_encoding in this case - {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", ResultSuccess}, + {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", true}, // allow client_encoding to be set explicitly - {"client_encoding=UTF8", "client_encoding", "UTF8", ResultSuccess}, + {"client_encoding=UTF8", "client_encoding", "UTF8", true}, // test a runtime parameter not supported by libpq - {"work_mem='139kB'", "work_mem", "139kB", ResultSuccess}, + {"work_mem='139kB'", "work_mem", "139kB", true}, // test fallback_application_name - {"application_name=foo fallback_application_name=bar", "application_name", "foo", ResultSuccess}, - {"application_name='' fallback_application_name=bar", "application_name", "", ResultSuccess}, - {"fallback_application_name=bar", "application_name", "bar", ResultSuccess}, + {"application_name=foo fallback_application_name=bar", "application_name", "foo", true}, + {"application_name='' fallback_application_name=bar", "application_name", "", true}, + {"fallback_application_name=bar", "application_name", "bar", true}, } for _, test := range tests { @@ -1236,23 +1493,23 @@ continue } - tryGetParameterValue := func() (value string, outcome RuntimeTestResult) { + tryGetParameterValue := func() (value string, success bool) { defer db.Close() row := db.QueryRow("SELECT current_setting($1)", test.param) err = row.Scan(&value) if err != nil { - return "", ResultError + return "", false } - return value, ResultSuccess + return value, true } - value, outcome := tryGetParameterValue() - if outcome != test.expectedOutcome && outcome == ResultError { + value, success := tryGetParameterValue() + if success != test.success && !test.success { t.Fatalf("%v: unexpected error: %v", test.conninfo, err) } - if outcome != test.expectedOutcome { + if success != test.success { t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"", - outcome, test.expectedOutcome, test.conninfo) + success, test.success, test.conninfo) } if value != test.expected { t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"", @@ -1304,3 +1561,217 @@ } } } + +func TestQuoteLiteral(t *testing.T) { + var cases = []struct { + input string + want string + }{ + {`foo`, `'foo'`}, + {`foo bar baz`, `'foo bar baz'`}, + {`foo'bar`, `'foo''bar'`}, + {`foo\bar`, ` E'foo\\bar'`}, + {`foo\ba'r`, ` E'foo\\ba''r'`}, + {`foo"bar`, `'foo"bar'`}, + {`foo\x00bar`, ` 
E'foo\\x00bar'`}, + {`\x00foo`, ` E'\\x00foo'`}, + {`'`, `''''`}, + {`''`, `''''''`}, + {`\`, ` E'\\'`}, + {`'abc'; DROP TABLE users;`, `'''abc''; DROP TABLE users;'`}, + {`\'`, ` E'\\'''`}, + {`E'\''`, ` E'E''\\'''''`}, + {`e'\''`, ` E'e''\\'''''`}, + {`E'\'abc\'; DROP TABLE users;'`, ` E'E''\\''abc\\''; DROP TABLE users;'''`}, + {`e'\'abc\'; DROP TABLE users;'`, ` E'e''\\''abc\\''; DROP TABLE users;'''`}, + } + + for _, test := range cases { + got := QuoteLiteral(test.input) + if got != test.want { + t.Errorf("QuoteLiteral(%q) = %v want %v", test.input, got, test.want) + } + } +} + +func TestRowsResultTag(t *testing.T) { + type ResultTag interface { + Result() driver.Result + Tag() string + } + + tests := []struct { + query string + tag string + ra int64 + }{ + { + query: "CREATE TEMP TABLE temp (a int)", + tag: "CREATE TABLE", + }, + { + query: "INSERT INTO temp VALUES (1), (2)", + tag: "INSERT", + ra: 2, + }, + { + query: "SELECT 1", + }, + // A SELECT anywhere should take precedence. + { + query: "SELECT 1; INSERT INTO temp VALUES (1), (2)", + }, + { + query: "INSERT INTO temp VALUES (1), (2); SELECT 1", + }, + // Multiple statements that don't return rows should return the last tag. + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t", + tag: "DROP TABLE", + }, + // Ensure a rows-returning query in any position among various tag-returning + // statements will prefer the rows. + { + query: "SELECT 1; CREATE TEMP TABLE t (a int); DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); SELECT 1; DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", + }, + // Verify that a no-results query doesn't set the tag. + { + query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", + }, + } + + // If this is the only test run, this will correct the connection string. + openTestConn(t).Close() + + conn, err := Open("") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + q := conn.(driver.QueryerContext) + + for _, test := range tests { + if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil { + t.Fatalf("%s: %s", test.query, err) + } else { + r := rows.(ResultTag) + if tag := r.Tag(); tag != test.tag { + t.Fatalf("%s: unexpected tag %q", test.query, tag) + } + res := r.Result() + if ra, _ := res.RowsAffected(); ra != test.ra { + t.Fatalf("%s: unexpected rows affected: %d", test.query, ra) + } + rows.Close() + } + } +} + +// TestQuickClose tests that closing a query early allows a subsequent query to work.
+func TestQuickClose(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + rows, err := tx.Query("SELECT 1; SELECT 2;") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + + var id int + if err := tx.QueryRow("SELECT 3").Scan(&id); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Fatalf("unexpected %d", id) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } +} + +func TestMultipleResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query(` + begin; + select * from information_schema.tables limit 1; + select * from information_schema.columns limit 2; + commit; + `) + if err != nil { + t.Fatal(err) + } + type set struct { + cols []string + rowCount int + } + buf := []*set{} + for { + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + s := &set{ + cols: cols, + } + buf = append(buf, s) + + for rows.Next() { + s.rowCount++ + } + if !rows.NextResultSet() { + break + } + } + if len(buf) != 2 { + t.Fatalf("got %d sets, expected 2", len(buf)) + } + if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 { + t.Fatal("invalid cols size, expected different column count and greater then zero") + } + if buf[0].rowCount != 1 || buf[1].rowCount != 2 { + t.Fatal("incorrect number of rows returned") + } +} + +func TestCopyInStmtAffectedRows(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Exec("CREATE TEMP TABLE temp (a int)") + if err != nil { + t.Fatal(err) + } + + txn, err := db.BeginTx(context.TODO(), nil) + if err != nil { + t.Fatal(err) + } + + copyStmt, err := txn.Prepare(CopyIn("temp", "a")) + if err != nil { + t.Fatal(err) + } + + res, err := copyStmt.Exec() + if err != nil { + t.Fatal(err) + } + + res.RowsAffected() + res.LastInsertId() +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/copy.go golang-github-lib-pq-1.3.0/copy.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/copy.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/copy.go 2019-12-11 04:41:10.000000000 +0000 @@ -13,6 +13,7 @@ errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY") errCopyToNotSupported = errors.New("pq: COPY TO is not supported") errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction") + errCopyInProgress = errors.New("pq: COPY in progress") ) // CopyIn creates a COPY FROM statement which can be prepared with @@ -96,13 +97,13 @@ err = parseError(r) case 'Z': if err == nil { - cn.bad = true + ci.setBad() errorf("unexpected ReadyForQuery in response to COPY") } cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for copy query: %q", t) } } @@ -121,7 +122,7 @@ cn.processReadyForQuery(r) return nil, err default: - cn.bad = true + ci.setBad() errorf("unknown response for CopyFail: %q", t) } } @@ -142,7 +143,7 @@ var r readBuf t, err := ci.cn.recvMessage(&r) if err != nil { - ci.cn.bad = true + ci.setBad() ci.setError(err) ci.done <- true return @@ -160,7 +161,7 @@ err := parseError(&r) ci.setError(err) default: - ci.cn.bad = true + ci.setBad() ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t)) ci.done <- true return @@ -168,6 +169,19 @@ } } +func (ci *copyin) setBad() { + ci.Lock() + ci.cn.bad = true + ci.Unlock() +} + +func (ci *copyin) isBad() bool { + ci.Lock() + b := ci.cn.bad + ci.Unlock() + return b +} + func (ci *copyin) isErrorSet() bool { 
ci.Lock() isSet := (ci.err != nil) @@ -205,7 +219,7 @@ return nil, errCopyInClosed } - if ci.cn.bad { + if ci.isBad() { return nil, driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -215,9 +229,7 @@ } if len(v) == 0 { - err = ci.Close() - ci.closed = true - return nil, err + return driver.RowsAffected(0), ci.Close() } numValues := len(v) @@ -240,11 +252,12 @@ } func (ci *copyin) Close() (err error) { - if ci.closed { - return errCopyInClosed + if ci.closed { // Don't do anything, we're already closed + return nil } + ci.closed = true - if ci.cn.bad { + if ci.isBad() { return driver.ErrBadConn } defer ci.cn.errRecover(&err) @@ -259,6 +272,7 @@ } <-ci.done + ci.cn.inCopy = false if ci.isErrorSet() { err = ci.err diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/copy_test.go golang-github-lib-pq-1.3.0/copy_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/copy_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/copy_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -3,13 +3,14 @@ import ( "bytes" "database/sql" + "database/sql/driver" + "net" "strings" "testing" ) func TestCopyInStmt(t *testing.T) { - var stmt string - stmt = CopyIn("table name") + stmt := CopyIn("table name") if stmt != `COPY "table name" () FROM STDIN` { t.Fatal(stmt) } @@ -26,8 +27,7 @@ } func TestCopyInSchemaStmt(t *testing.T) { - var stmt string - stmt = CopyInSchema("schema name", "table name") + stmt := CopyInSchema("schema name", "table name") if stmt != `COPY "schema name"."table name" () FROM STDIN` { t.Fatal(stmt) } @@ -225,7 +225,7 @@ if text != "Héllö\n ☃!\r\t\\" { t.Fatal("unexpected result", text) } - if bytes.Compare(blob, []byte{0, 255, 9, 10, 13}) != 0 { + if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) { t.Fatal("unexpected result", blob) } if nothing.Valid { @@ -381,6 +381,7 @@ if err != nil { t.Fatal(err) } + defer stmt.Close() _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) if err != nil { @@ -400,17 +401,22 @@ if err == nil { t.Fatalf("expected error") } - pge, ok := err.(*Error) - if !ok { - t.Fatalf("expected *pq.Error, got %+#v", err) - } else if pge.Code.Name() != "admin_shutdown" { - t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) + switch pge := err.(type) { + case *Error: + if pge.Code.Name() != "admin_shutdown" { + t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) + } + case *net.OpError: + // ignore + default: + if err == driver.ErrBadConn { + // likely an EPIPE + } else { + t.Fatalf("unexpected error, got %+#v", err) + } } - err = stmt.Close() - if err != nil { - t.Fatal(err) - } + _ = stmt.Close() } func BenchmarkCopyIn(b *testing.B) { diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/changelog golang-github-lib-pq-1.3.0/debian/changelog --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/changelog 2018-11-29 20:38:51.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/changelog 2020-01-17 13:05:51.000000000 +0000 @@ -1,3 +1,15 @@ +golang-github-lib-pq (1.3.0-1) unstable; urgency=medium + + [ Daniel Swarbrick ] + * New upstream release. + * Add myself to uploaders. + * Update Standards-Version to 4.4.1 (no changes). + * Update debhelper version to 12. + * d/gbp.conf: stop using pristine-tar. + * Add debian/watch file. 
+ + -- Daniel Swarbrick Fri, 17 Jan 2020 13:05:51 +0000 + golang-github-lib-pq (0.0~git20151007.0.ffe986a-2) unstable; urgency=medium [ Paul Tagliamonte ] diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/compat golang-github-lib-pq-1.3.0/debian/compat --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/compat 2018-11-29 20:38:23.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/compat 1970-01-01 00:00:00.000000000 +0000 @@ -1 +0,0 @@ -11 diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/control golang-github-lib-pq-1.3.0/debian/control --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/control 2018-11-29 20:38:39.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/control 2020-01-17 13:05:51.000000000 +0000 @@ -1,11 +1,16 @@ Source: golang-github-lib-pq Section: devel Priority: optional +Testsuite: autopkgtest-pkg-go Maintainer: Debian Go Packaging Team Uploaders: Michael Stapelberg , - Tianon Gravi -Build-Depends: debhelper (>= 11~), dh-golang, golang-any -Standards-Version: 4.2.1 + Tianon Gravi , + Daniel Swarbrick , +Build-Depends: debhelper-compat (= 12), + dh-golang, + golang-any, +Standards-Version: 4.4.1 +Rules-Requires-Root: no Homepage: https://github.com/lib/pq Vcs-Browser: https://salsa.debian.org/go-team/packages/golang-github-lib-pq Vcs-Git: https://salsa.debian.org/go-team/packages/golang-github-lib-pq.git @@ -13,12 +18,13 @@ Package: golang-github-lib-pq-dev Architecture: all -Depends: ${shlibs:Depends}, ${misc:Depends} -Replaces: golang-pq-dev (<< 0.0~git20151007.0.ffe986a-1~) -Breaks: golang-pq-dev (<< 0.0~git20151007.0.ffe986a-1~) -Provides: golang-pq-dev -Description: pure Go postgres driver for Go’s database/sql package - After importing this package, you can connect to a postgres database from your +Depends: ${misc:Depends}, + ${shlibs:Depends}, +Replaces: golang-pq-dev (<< 0.0~git20151007.0.ffe986a-1~), +Breaks: golang-pq-dev (<< 0.0~git20151007.0.ffe986a-1~), +Provides: golang-pq-dev, +Description: Pure Go Postgres driver for Go’s database/sql package + After importing this package, you can connect to a Postgres database from your Go programs. This package does not depend on libpq-dev and does not need cgo, making it suitable for use when cross-compiling. . @@ -27,7 +33,8 @@ Package: golang-pq-dev Section: oldlibs Architecture: all -Depends: ${misc:Depends}, golang-github-lib-pq-dev +Depends: golang-github-lib-pq-dev, + ${misc:Depends}, Description: Transitional package for golang-github-lib-pq-dev This is a transitional package to ease upgrades to the golang-github-lib-pq-dev package. It can safely be removed. diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/gbp.conf golang-github-lib-pq-1.3.0/debian/gbp.conf --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/gbp.conf 2018-11-29 20:32:11.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/gbp.conf 2020-01-17 13:05:51.000000000 +0000 @@ -1,2 +1,6 @@ [DEFAULT] -pristine-tar = True +debian-branch = debian/sid + +[buildpackage] +dist = DEP14 +upstream-tag = upstream/%(version)s diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/source/lintian-overrides golang-github-lib-pq-1.3.0/debian/source/lintian-overrides --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/source/lintian-overrides 2018-11-29 20:32:11.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/source/lintian-overrides 1970-01-01 00:00:00.000000000 +0000 @@ -1,4 +0,0 @@ -# debian/watch is missing intentionally. 
There are no releases, new code goes -# directly into the repository. In case codesearch gets a new maintainer -# who does it differently, debian/watch can be added. -golang-github-lib-pq source: debian-watch-file-is-missing diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/tests/control golang-github-lib-pq-1.3.0/debian/tests/control --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/tests/control 2018-11-29 20:32:11.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/tests/control 2020-01-17 13:05:51.000000000 +0000 @@ -1,3 +1,5 @@ Tests: go-test -Depends: @, postgresql, golang-go +Depends: golang-go, + postgresql, + @, Restrictions: needs-root diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/watch golang-github-lib-pq-1.3.0/debian/watch --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/debian/watch 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/debian/watch 2020-01-17 13:05:51.000000000 +0000 @@ -0,0 +1,3 @@ +version=4 +opts=filenamemangle=s/.+\/v?(\d\S+)\.tar\.gz/pq-$1\.tar\.gz/ \ + https://github.com/lib/pq/tags .*/v?(\d\S+)\.tar\.gz diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/doc.go golang-github-lib-pq-1.3.0/doc.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/doc.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/doc.go 2019-12-11 04:41:10.000000000 +0000 @@ -11,7 +11,8 @@ ) func main() { - db, err := sql.Open("postgres", "user=pqgotest dbname=pqgotest sslmode=verify-full") + connStr := "user=pqgotest dbname=pqgotest sslmode=verify-full" + db, err := sql.Open("postgres", connStr) if err != nil { log.Fatal(err) } @@ -23,7 +24,8 @@ You can also connect to a database using a URL. For example: - db, err := sql.Open("postgres", "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full") + connStr := "postgres://pqgotest:password@localhost/pqgotest?sslmode=verify-full" + db, err := sql.Open("postgres", connStr) Connection String Parameters @@ -43,21 +45,28 @@ * dbname - The name of the database to connect to * user - The user to sign in as * password - The user's password - * host - The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) + * host - The host to connect to. Values that start with / are for unix + domain sockets. (default is localhost) * port - The port to bind to. (default is 5432) - * sslmode - Whether or not to use SSL (default is require, this is not the default for libpq) + * sslmode - Whether or not to use SSL (default is require, this is not + the default for libpq) * fallback_application_name - An application_name to fall back to if one isn't provided. - * connect_timeout - Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. + * connect_timeout - Maximum wait for connection, in seconds. Zero or + not specified means wait indefinitely. * sslcert - Cert file location. The file must contain PEM encoded data. * sslkey - Key file location. The file must contain PEM encoded data. - * sslrootcert - The location of the root certificate file. The file must contain PEM encoded data. + * sslrootcert - The location of the root certificate file. The file + must contain PEM encoded data. 
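For illustration, a connection string combining several of the parameters above might look like the following editorial sketch (the host, database, credentials, and application name are placeholders, not values taken from this patch):

    connStr := "host=db.example.internal port=5432 dbname=pqgotest user=pqgotest password=secret sslmode=verify-full connect_timeout=10 fallback_application_name=reports"
    db, err := sql.Open("postgres", connStr)
    if err != nil {
        log.Fatal(err)
    }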
Valid values for sslmode are: * disable - No SSL * require - Always SSL (skip verification) - * verify-ca - Always SSL (verify that the certificate presented by the server was signed by a trusted CA) - * verify-full - Always SSL (verify that the certification presented by the server was signed by a trusted CA and the server host name matches the one in the certificate) + * verify-ca - Always SSL (verify that the certificate presented by the + server was signed by a trusted CA) + * verify-full - Always SSL (verify that the certification presented by + the server was signed by a trusted CA and the server host name + matches the one in the certificate) See http://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING for more information about connection string parameters. @@ -68,7 +77,7 @@ A backslash will escape the next character in values: - "user=space\ man password='it\'s valid' + "user=space\ man password='it\'s valid'" Note that the connection parameter client_encoding (which sets the text encoding for the connection) may be set but must be "UTF8", @@ -86,9 +95,13 @@ establishment. Environment variables have a lower precedence than explicitly provided connection parameters. +The pgpass mechanism as described in http://www.postgresql.org/docs/current/static/libpq-pgpass.html +is supported, but on Windows PGPASSFILE must be specified explicitly. + Queries + database/sql does not dictate any specific format for parameter markers in query strings, and pq uses the Postgres-native ordinal markers, as shown above. The same marker can be reused for the same parameter: @@ -112,8 +125,30 @@ For additional instructions on querying see the documentation for the database/sql package. + +Data Types + + +Parameters pass through driver.DefaultParameterConverter before they are handled +by this package. When the binary_parameters connection option is enabled, +[]byte values are sent directly to the backend as data in binary format. + +This package returns the following types for values from the PostgreSQL backend: + + - integer types smallint, integer, and bigint are returned as int64 + - floating-point types real and double precision are returned as float64 + - character types char, varchar, and text are returned as string + - temporal types date, time, timetz, timestamp, and timestamptz are + returned as time.Time + - the boolean type is returned as bool + - the bytea type is returned as []byte + +All other types are returned directly from the backend as []byte values in text format. + + Errors + pq may return errors of type *pq.Error which can be interrogated for error details: if err, ok := err.(*pq.Error); ok { @@ -204,7 +239,7 @@ bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at -http://godoc.org/github.com/lib/pq/listen_example. +https://godoc.org/github.com/lib/pq/example/listen. 
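As a quick illustration of the type mappings listed under Data Types above, a minimal editorial sketch (the table and columns are hypothetical, not part of this patch):

    var (
        id      int64     // integer, bigint -> int64
        ratio   float64   // real, double precision -> float64
        name    string    // char, varchar, text -> string
        created time.Time // timestamp, timestamptz -> time.Time
    )
    err := db.QueryRow("SELECT id, ratio, name, created FROM widgets WHERE id = $1", 1).Scan(&id, &ratio, &name, &created)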
*/ package pq diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/encode.go golang-github-lib-pq-1.3.0/encode.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/encode.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/encode.go 2019-12-11 04:41:10.000000000 +0000 @@ -5,6 +5,7 @@ "database/sql/driver" "encoding/binary" "encoding/hex" + "errors" "fmt" "math" "strconv" @@ -22,7 +23,6 @@ default: return encode(parameterStatus, x, oid.T_unknown) } - panic("not reached") } func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte { @@ -56,10 +56,13 @@ } func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} { - if f == formatBinary { + switch f { + case formatBinary: return binaryDecode(parameterStatus, s, typ) - } else { + case formatText: return textDecode(parameterStatus, s, typ) + default: + panic("not reached") } } @@ -73,9 +76,15 @@ return int64(int32(binary.BigEndian.Uint32(s))) case oid.T_int2: return int64(int16(binary.BigEndian.Uint16(s))) + case oid.T_uuid: + b, err := decodeUUIDBinary(s) + if err != nil { + panic(err) + } + return b default: - errorf("don't know how to decode binary parameter of type %u", uint32(typ)) + errorf("don't know how to decode binary parameter of type %d", uint32(typ)) } panic("not reached") @@ -83,8 +92,14 @@ func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} { switch typ { + case oid.T_char, oid.T_varchar, oid.T_text: + return string(s) case oid.T_bytea: - return parseBytea(s) + b, err := parseBytea(s) + if err != nil { + errorf("%s", err) + } + return b case oid.T_timestamptz: return parseTs(parameterStatus.currentLocation, string(s)) case oid.T_timestamp, oid.T_date: @@ -102,11 +117,10 @@ } return i case oid.T_float4, oid.T_float8: - bits := 64 - if typ == oid.T_float4 { - bits = 32 - } - f, err := strconv.ParseFloat(string(s), bits) + // We always use 64 bit parsing, regardless of whether the input text is for + // a float4 or float8, because clients expect float64s for all float datatypes + // and returning a 32-bit parsed float64 produces lossy results. + f, err := strconv.ParseFloat(string(s), 64) if err != nil { errorf("%s", err) } @@ -195,16 +209,39 @@ return t } -func expect(str, char string, pos int) { - if c := str[pos : pos+1]; c != char { - errorf("expected '%v' at position %v; got '%v'", char, pos, c) +var errInvalidTimestamp = errors.New("invalid timestamp") + +type timestampParser struct { + err error +} + +func (p *timestampParser) expect(str string, char byte, pos int) { + if p.err != nil { + return + } + if pos+1 > len(str) { + p.err = errInvalidTimestamp + return + } + if c := str[pos]; c != char && p.err == nil { + p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c) } } -func mustAtoi(str string) int { - result, err := strconv.Atoi(str) +func (p *timestampParser) mustAtoi(str string, begin int, end int) int { + if p.err != nil { + return 0 + } + if begin < 0 || end < 0 || begin > end || end > len(str) { + p.err = errInvalidTimestamp + return 0 + } + result, err := strconv.Atoi(str[begin:end]) if err != nil { - errorf("expected number; got '%v'", str) + if p.err == nil { + p.err = fmt.Errorf("expected number; got '%v'", str) + } + return 0 } return result } @@ -219,7 +256,7 @@ // about 5% speed could be gained by putting the cache in the connection and // losing the mutex, at the cost of a small amount of memory and a somewhat // significant increase in code complexity. 
-var globalLocationCache *locationCache = newLocationCache() +var globalLocationCache = newLocationCache() func newLocationCache() *locationCache { return &locationCache{cache: make(map[int]*time.Location)} @@ -249,26 +286,26 @@ infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive" ) -/* - * If EnableInfinityTs is not called, "-infinity" and "infinity" will return - * []byte("-infinity") and []byte("infinity") respectively, and potentially - * cause error "sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time", - * when scanning into a time.Time value. - * - * Once EnableInfinityTs has been called, all connections created using this - * driver will decode Postgres' "-infinity" and "infinity" for "timestamp", - * "timestamp with time zone" and "date" types to the predefined minimum and - * maximum times, respectively. When encoding time.Time values, any time which - * equals or preceeds the predefined minimum time will be encoded to - * "-infinity". Any values at or past the maximum time will similarly be - * encoded to "infinity". - * - * - * If EnableInfinityTs is called with negative >= positive, it will panic. - * Calling EnableInfinityTs after a connection has been established results in - * undefined behavior. If EnableInfinityTs is called more than once, it will - * panic. - */ +// EnableInfinityTs controls the handling of Postgres' "-infinity" and +// "infinity" "timestamp"s. +// +// If EnableInfinityTs is not called, "-infinity" and "infinity" will return +// []byte("-infinity") and []byte("infinity") respectively, and potentially +// cause error "sql: Scan error on column index 0: unsupported driver -> Scan +// pair: []uint8 -> *time.Time", when scanning into a time.Time value. +// +// Once EnableInfinityTs has been called, all connections created using this +// driver will decode Postgres' "-infinity" and "infinity" for "timestamp", +// "timestamp with time zone" and "date" types to the predefined minimum and +// maximum times, respectively. When encoding time.Time values, any time which +// equals or precedes the predefined minimum time will be encoded to +// "-infinity". Any values at or past the maximum time will similarly be +// encoded to "infinity". +// +// If EnableInfinityTs is called with negative >= positive, it will panic. +// Calling EnableInfinityTs after a connection has been established results in +// undefined behavior. If EnableInfinityTs is called more than once, it will +// panic. func EnableInfinityTs(negative time.Time, positive time.Time) { if infinityTsEnabled { panic(infinityTsEnabledAlready) @@ -305,28 +342,48 @@ } return []byte(str) } + t, err := ParseTimestamp(currentLocation, str) + if err != nil { + panic(err) + } + return t +} + +// ParseTimestamp parses Postgres' text format. It returns a time.Time in +// currentLocation iff that time's offset agrees with the offset sent from the +// Postgres server. Otherwise, ParseTimestamp returns a time.Time with the +// fixed offset offset provided by the Postgres server. 
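+// For illustration, a hedged usage sketch of this exported function
+// (editorial addition, not part of the patch; the nil location mirrors this
+// package's own tests, and the input string is taken from the test table):
+//
+//	t, err := pq.ParseTimestamp(nil, "2001-02-03 04:05:06.000001")
+//	if err != nil {
+//		log.Fatal(err)
+//	}
+//	fmt.Println(t.UTC())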
+func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) { + p := timestampParser{} monSep := strings.IndexRune(str, '-') // this is Gregorian year, not ISO Year // In Gregorian system, the year 1 BC is followed by AD 1 - year := mustAtoi(str[:monSep]) + year := p.mustAtoi(str, 0, monSep) daySep := monSep + 3 - month := mustAtoi(str[monSep+1 : daySep]) - expect(str, "-", daySep) + month := p.mustAtoi(str, monSep+1, daySep) + p.expect(str, '-', daySep) timeSep := daySep + 3 - day := mustAtoi(str[daySep+1 : timeSep]) + day := p.mustAtoi(str, daySep+1, timeSep) + + minLen := monSep + len("01-01") + 1 + + isBC := strings.HasSuffix(str, " BC") + if isBC { + minLen += 3 + } var hour, minute, second int - if len(str) > monSep+len("01-01")+1 { - expect(str, " ", timeSep) + if len(str) > minLen { + p.expect(str, ' ', timeSep) minSep := timeSep + 3 - expect(str, ":", minSep) - hour = mustAtoi(str[timeSep+1 : minSep]) + p.expect(str, ':', minSep) + hour = p.mustAtoi(str, timeSep+1, minSep) secSep := minSep + 3 - expect(str, ":", secSep) - minute = mustAtoi(str[minSep+1 : secSep]) + p.expect(str, ':', secSep) + minute = p.mustAtoi(str, minSep+1, secSep) secEnd := secSep + 3 - second = mustAtoi(str[secSep+1 : secEnd]) + second = p.mustAtoi(str, secSep+1, secEnd) } remainderIdx := monSep + len("01-01 00:00:00") + 1 // Three optional (but ordered) sections follow: the @@ -337,49 +394,51 @@ nanoSec := 0 tzOff := 0 - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+1] == "." { + if remainderIdx < len(str) && str[remainderIdx] == '.' { fracStart := remainderIdx + 1 fracOff := strings.IndexAny(str[fracStart:], "-+ ") if fracOff < 0 { fracOff = len(str) - fracStart } - fracSec := mustAtoi(str[fracStart : fracStart+fracOff]) + fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff) nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff)))) remainderIdx += fracOff + 1 } - if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart:tzStart+1] == "-" || str[tzStart:tzStart+1] == "+") { + if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') { // time zone separator is always '-' or '+' (UTC is +00) var tzSign int - if c := str[tzStart : tzStart+1]; c == "-" { + switch c := str[tzStart]; c { + case '-': tzSign = -1 - } else if c == "+" { + case '+': tzSign = +1 - } else { - errorf("expected '-' or '+' at position %v; got %v", tzStart, c) + default: + return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c) } - tzHours := mustAtoi(str[tzStart+1 : tzStart+3]) + tzHours := p.mustAtoi(str, tzStart+1, tzStart+3) remainderIdx += 3 var tzMin, tzSec int - if tzStart+3 < len(str) && str[tzStart+3:tzStart+4] == ":" { - tzMin = mustAtoi(str[tzStart+4 : tzStart+6]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } - if tzStart+6 < len(str) && str[tzStart+6:tzStart+7] == ":" { - tzSec = mustAtoi(str[tzStart+7 : tzStart+9]) + if remainderIdx < len(str) && str[remainderIdx] == ':' { + tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3) remainderIdx += 3 } tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec) } var isoYear int - if remainderIdx < len(str) && str[remainderIdx:remainderIdx+3] == " BC" { + + if isBC { isoYear = 1 - year remainderIdx += 3 } else { isoYear = year } if remainderIdx < len(str) { - errorf("expected end of input, got %v", str[remainderIdx:]) + return time.Time{}, fmt.Errorf("expected 
end of input, got %v", str[remainderIdx:]) } t := time.Date(isoYear, time.Month(month), day, hour, minute, second, nanoSec, @@ -396,11 +455,11 @@ } } - return t + return t, p.err } // formatTs formats t into a format postgres understands. -func formatTs(t time.Time) (b []byte) { +func formatTs(t time.Time) []byte { if infinityTsEnabled { // t <= -infinity : ! (t > -infinity) if !t.After(infinityTsNegative) { @@ -411,6 +470,11 @@ return []byte("infinity") } } + return FormatTimestamp(t) +} + +// FormatTimestamp formats t into Postgres' text format for timestamps. +func FormatTimestamp(t time.Time) []byte { // Need to send dates before 0001 A.D. with " BC" suffix, instead of the // minus sign preferred by Go. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on @@ -420,10 +484,10 @@ t = t.AddDate((-t.Year())*2+1, 0, 0) bc = true } - b = []byte(t.Format(time.RFC3339Nano)) + b := []byte(t.Format("2006-01-02 15:04:05.999999999Z07:00")) _, offset := t.Zone() - offset = offset % 60 + offset %= 60 if offset != 0 { // RFC3339Nano already printed the minus sign if offset < 0 { @@ -445,14 +509,14 @@ // Parse a bytea value received from the server. Both "hex" and the legacy // "escape" format are supported. -func parseBytea(s []byte) (result []byte) { +func parseBytea(s []byte) (result []byte, err error) { if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { // bytea_output = hex s = s[2:] // trim off leading "\\x" result = make([]byte, hex.DecodedLen(len(s))) _, err := hex.Decode(result, s) if err != nil { - errorf("%s", err) + return nil, err } } else { // bytea_output = escape @@ -467,11 +531,11 @@ // '\\' followed by an octal number if len(s) < 4 { - errorf("invalid bytea sequence %v", s) + return nil, fmt.Errorf("invalid bytea sequence %v", s) } r, err := strconv.ParseInt(string(s[1:4]), 8, 9) if err != nil { - errorf("could not parse bytea value: %s", err.Error()) + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) } result = append(result, byte(r)) s = s[4:] @@ -489,7 +553,7 @@ } } - return result + return result, nil } func encodeBytea(serverVersion int, v []byte) (result []byte) { diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/encode_test.go golang-github-lib-pq-1.3.0/encode_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/encode_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/encode_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -4,6 +4,7 @@ "bytes" "database/sql" "fmt" + "regexp" "testing" "time" @@ -36,6 +37,8 @@ }{ {"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"0001-12-31 BC", time.Date(0, time.December, 31, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"2001-02-03 BC", time.Date(-2000, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))}, {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))}, @@ -71,26 +74,10 @@ {"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, } -// Helper function for the two tests below -func tryParse(str string) (t time.Time, err error) { - defer func() { - if p := recover(); p != nil { - err = fmt.Errorf("%v", p) - return - } - }() - i := parseTs(nil, str) - t, 
ok := i.(time.Time) - if !ok { - err = fmt.Errorf("Not a time.Time type, got %#v", i) - } - return -} - // Test that parsing the string results in the expected value. func TestParseTs(t *testing.T) { for i, tt := range timeTests { - val, err := tryParse(tt.str) + val, err := ParseTimestamp(nil, tt.str) if err != nil { t.Errorf("%d: got error: %v", i, err) } else if val.String() != tt.timeval.String() { @@ -100,6 +87,36 @@ } } +var timeErrorTests = []string{ + "BC", + " BC", + "2001", + "2001-2-03", + "2001-02-3", + "2001-02-03 ", + "2001-02-03 B", + "2001-02-03 04", + "2001-02-03 04:", + "2001-02-03 04:05", + "2001-02-03 04:05 B", + "2001-02-03 04:05 BC", + "2001-02-03 04:05:", + "2001-02-03 04:05:6", + "2001-02-03 04:05:06 B", + "2001-02-03 04:05:06BC", + "2001-02-03 04:05:06.123 B", +} + +// Test that parsing the string results in an error. +func TestParseTsErrors(t *testing.T) { + for i, tt := range timeErrorTests { + _, err := ParseTimestamp(nil, tt) + if err == nil { + t.Errorf("%d: expected an error from parsing: %v", i, tt) + } + } +} + // Now test that sending the value into the database and parsing it back // returns the same time.Time value. func TestEncodeAndParseTs(t *testing.T) { @@ -117,7 +134,7 @@ continue } - val, err := tryParse(dbstr) + val, err := ParseTimestamp(nil, dbstr) if err != nil { t.Errorf("%d: could not parse value %q: %s", i, dbstr, err) continue @@ -133,22 +150,22 @@ time time.Time expected string }{ - {time.Time{}, "0001-01-01T00:00:00Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03T04:05:06.123456789Z"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03T04:05:06.123456789+02:00"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03T04:05:06.123456789-06:00"}, - {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03T04:05:06-07:30:09"}, - - {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z"}, - {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00"}, - {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00"}, - - {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03T04:05:06.123456789Z BC"}, - {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03T04:05:06.123456789+02:00 BC"}, - {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03T04:05:06.123456789-06:00 BC"}, + {time.Time{}, "0001-01-01 00:00:00Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, 
time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"}, + + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"}, - {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09"}, - {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03T04:05:06-07:30:09 BC"}, + {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"}, + {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"}, } func TestFormatTs(t *testing.T) { @@ -160,6 +177,26 @@ } } +func TestFormatTsBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var str string + err := db.QueryRow("SELECT '2001-02-03T04:05:06.007-08:09:10'::time::text").Scan(&str) + if err == nil { + t.Fatalf("PostgreSQL is accepting an ISO timestamp input for time") + } + + for i, tt := range formatTimeTests { + for _, typ := range []string{"date", "time", "timetz", "timestamp", "timestamptz"} { + err = db.QueryRow("SELECT $1::"+typ+"::text", tt.time).Scan(&str) + if err != nil { + t.Errorf("%d: incorrect time format for %v on the backend: %v", i, typ, err) + } + } + } +} + func TestTimestampWithTimeZone(t *testing.T) { db := openTestConn(t) defer db.Close() @@ -230,9 +267,7 @@ t.Fatalf("Could not run query: %v", err) } - n := r.Next() - - if n != true { + if !r.Next() { t.Fatal("Expected at least one row") } @@ -252,8 +287,7 @@ expected, result) } - n = r.Next() - if n != false { + if r.Next() { t.Fatal("Expected only one row") } } @@ -270,24 +304,27 @@ var err error var resultT time.Time - expectedError := fmt.Errorf(`sql: Scan error on column index 0: unsupported driver -> Scan pair: []uint8 -> *time.Time`) + expectedErrorStrRegexp := regexp.MustCompile( + `^sql: Scan error on column index 0(, name "timestamp(tz)?"|): unsupported`) + type testCases []struct { - Query string - Param string - ExpectedErr error - ExpectedVal interface{} + Query string + Param string + ExpectedErrorStrRegexp *regexp.Regexp + ExpectedVal interface{} } tc := testCases{ - {"SELECT $1::timestamp", "-infinity", expectedError, "-infinity"}, - {"SELECT $1::timestamptz", "-infinity", expectedError, "-infinity"}, - {"SELECT $1::timestamp", "infinity", expectedError, "infinity"}, - {"SELECT $1::timestamptz", "infinity", expectedError, "infinity"}, + {"SELECT $1::timestamp", "-infinity", expectedErrorStrRegexp, "-infinity"}, + {"SELECT $1::timestamptz", "-infinity", expectedErrorStrRegexp, "-infinity"}, + {"SELECT $1::timestamp", "infinity", expectedErrorStrRegexp, "infinity"}, + {"SELECT $1::timestamptz", "infinity", expectedErrorStrRegexp, "infinity"}, } // try to assert []byte to time.Time for _, q := range tc { err = db.QueryRow(q.Query, q.Param).Scan(&resultT) - if err.Error() != q.ExpectedErr.Error() { - t.Errorf("Scanning -/+infinity, expected error, %q, got %q", q.ExpectedErr, err) + if err == nil || !q.ExpectedErrorStrRegexp.MatchString(err.Error()) { + t.Errorf("Scanning -/+infinity, expected error to match regexp %q, got %q", + q.ExpectedErrorStrRegexp, err) } } // yield []byte @@ -342,17 +379,17 @@ t.Errorf("Scanning 
-infinity, expected time %q, got %q", y1500, resultT.String()) } - y_1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) + ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC) var s string - err = db.QueryRow("SELECT $1::timestamp::text", y_1500).Scan(&s) + err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s) if err != nil { t.Errorf("Encoding -infinity, expected no error, got %q", err) } if s != "-infinity" { t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) } - err = db.QueryRow("SELECT $1::timestamptz::text", y_1500).Scan(&s) + err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s) if err != nil { t.Errorf("Encoding -infinity, expected no error, got %q", err) } @@ -468,11 +505,11 @@ db := openTestConn(t) defer db.Close() - b := []byte{'\xa0','\xee','\xbc','\x99', - '\x9c', '\x0b', - '\x4e', '\xf8', - '\xbb', '\x00', '\x6b', - '\xb9', '\xbd', '\x38', '\x0a', '\x11'} + b := []byte{'\xa0', '\xee', '\xbc', '\x99', + '\x9c', '\x0b', + '\x4e', '\xf8', + '\xbb', '\x00', '\x6b', + '\xb9', '\xbd', '\x38', '\x0a', '\x11'} row := db.QueryRow("SELECT $1::uuid", b) var result string @@ -567,6 +604,17 @@ } } +func TestTextDecodeIntoString(t *testing.T) { + input := []byte("hello world") + want := string(input) + for _, typ := range []oid.Oid{oid.T_char, oid.T_varchar, oid.T_text} { + got := decode(¶meterStatus{}, input, typ, formatText) + if got != want { + t.Errorf("invalid string decoding output for %T(%+v), got %v but expected %v", typ, typ, got, want) + } + } +} + func TestByteaOutputFormatEncoding(t *testing.T) { input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123") want := []byte("\\x5c78000102fffe6162636465666730313233") @@ -683,8 +731,7 @@ } func TestAppendEscapedTextExistingBuffer(t *testing.T) { - var buf []byte - buf = []byte("123\t") + buf := []byte("123\t") if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" { t.Fatal(string(esc)) } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/error.go golang-github-lib-pq-1.3.0/error.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/error.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/error.go 2019-12-11 04:41:10.000000000 +0000 @@ -153,6 +153,7 @@ "22004": "null_value_not_allowed", "22002": "null_value_no_indicator_parameter", "22003": "numeric_value_out_of_range", + "2200H": "sequence_generator_limit_exceeded", "22026": "string_data_length_mismatch", "22001": "string_data_right_truncation", "22011": "substring_error", @@ -459,6 +460,11 @@ panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) } +// TODO(ainar-g) Rename to errorf after removing panics. 
+func fmterrorf(s string, args ...interface{}) error { + return fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)) +} + func errRecoverNoErrBadConn(err *error) { e := recover() if e == nil { @@ -472,13 +478,13 @@ } } -func (c *conn) errRecover(err *error) { +func (cn *conn) errRecover(err *error) { e := recover() switch v := e.(type) { case nil: // Do nothing case runtime.Error: - c.bad = true + cn.bad = true panic(v) case *Error: if v.Fatal() { @@ -487,7 +493,8 @@ *err = v } case *net.OpError: - *err = driver.ErrBadConn + cn.bad = true + *err = v case error: if v == io.EOF || v.(error).Error() == "remote error: handshake failure" { *err = driver.ErrBadConn @@ -496,13 +503,13 @@ } default: - c.bad = true + cn.bad = true panic(fmt.Sprintf("unknown error: %#v", e)) } // Any time we return ErrBadConn, we need to remember it since *Tx doesn't // mark the connection bad in database/sql. if *err == driver.ErrBadConn { - c.bad = true + cn.bad = true } } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/example/listen/doc.go golang-github-lib-pq-1.3.0/example/listen/doc.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/example/listen/doc.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/example/listen/doc.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,98 @@ +/* + +Package listen is a self-contained Go program which uses the LISTEN / NOTIFY +mechanism to avoid polling the database while waiting for more work to arrive. + + // + // You can see the program in action by defining a function similar to + // the following: + // + // CREATE OR REPLACE FUNCTION public.get_work() + // RETURNS bigint + // LANGUAGE sql + // AS $$ + // SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END + // $$ + // ; + + package main + + import ( + "database/sql" + "fmt" + "time" + + "github.com/lib/pq" + ) + + func doWork(db *sql.DB, work int64) { + // work here + } + + func getWork(db *sql.DB) { + for { + // get work from the database here + var work sql.NullInt64 + err := db.QueryRow("SELECT get_work()").Scan(&work) + if err != nil { + fmt.Println("call to get_work() failed: ", err) + time.Sleep(10 * time.Second) + continue + } + if !work.Valid { + // no more work to do + fmt.Println("ran out of work") + return + } + + fmt.Println("starting work on ", work.Int64) + go doWork(db, work.Int64) + } + } + + func waitForNotification(l *pq.Listener) { + select { + case <-l.Notify: + fmt.Println("received notification, new work available") + case <-time.After(90 * time.Second): + go l.Ping() + // Check if there's more work available, just in case it takes + // a while for the Listener to notice connection loss and + // reconnect. 
+ fmt.Println("received no work for 90 seconds, checking for new work") + } + } + + func main() { + var conninfo string = "" + + db, err := sql.Open("postgres", conninfo) + if err != nil { + panic(err) + } + + reportProblem := func(ev pq.ListenerEventType, err error) { + if err != nil { + fmt.Println(err.Error()) + } + } + + minReconn := 10 * time.Second + maxReconn := time.Minute + listener := pq.NewListener(conninfo, minReconn, maxReconn, reportProblem) + err = listener.Listen("getwork") + if err != nil { + panic(err) + } + + fmt.Println("entering main loop") + for { + // process all available work before waiting for notifications + getWork(db) + waitForNotification(listener) + } + } + + +*/ +package listen diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/.gitignore golang-github-lib-pq-1.3.0/.gitignore --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/.gitignore 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/.gitignore 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,4 @@ +.db +*.test +*~ +*.swp diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/go18_test.go golang-github-lib-pq-1.3.0/go18_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/go18_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/go18_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,319 @@ +package pq + +import ( + "context" + "database/sql" + "runtime" + "strings" + "testing" + "time" +) + +func TestMultipleSimpleQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1; set time zone default; select 2; select 3") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + var i int + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 1 { + t.Fatalf("expected 1, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + + // Make sure that if we ignore a result we can still query. + + rows, err = db.Query("select 4; select 5") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 4 { + t.Fatalf("expected 4, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 5 { + t.Fatalf("expected 5, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} + +const contextRaceIterations = 100 + +func TestContextCancelExec(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Context is already canceled, so error should come before execution. 
+ if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if _, err := db.ExecContext(ctx, "select 1"); err != nil { + t.Fatal(err) + } + }() + + if _, err := db.Exec("select 1"); err != nil { + t.Fatal(err) + } + } +} + +func TestContextCancelQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.QueryContext has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Context is already canceled, so error should come before execution. + if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + rows, err := db.QueryContext(ctx, "select 1") + cancel() + if err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + }() + + if rows, err := db.Query("select 1"); err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + } +} + +// TestIssue617 tests that a failed query in QueryContext doesn't lead to a +// goroutine leak. +func TestIssue617(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + const N = 10 + + numGoroutineStart := runtime.NumGoroutine() + for i := 0; i < N; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`) + pqErr, _ := err.(*Error) + // Expecting "pq: relation \"doesnotexist\" does not exist" error. + if err == nil || pqErr == nil || pqErr.Code != "42P01" { + t.Fatalf("expected undefined table error, got %v", err) + } + }() + } + numGoroutineFinish := runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart >= N/2 { + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + } +} + +func TestContextCancelBegin(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // Delay execution for just a bit until tx.Exec has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := tx.Exec("select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Transaction is canceled, so expect an error. 
+ if _, err := tx.Query("select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err != sql.ErrTxDone { + t.Fatalf("unexpected error: %s", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := db.BeginTx(ctx, nil); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + cancel() + if err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil && + err.Error() != "pq: canceling statement due to user request" && + err != sql.ErrTxDone { + t.Fatal(err) + } + }() + + if tx, err := db.Begin(); err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + } +} + +func TestTxOptions(t *testing.T) { + db := openTestConn(t) + defer db.Close() + ctx := context.Background() + + tests := []struct { + level sql.IsolationLevel + isolation string + }{ + { + level: sql.LevelDefault, + isolation: "", + }, + { + level: sql.LevelReadUncommitted, + isolation: "read uncommitted", + }, + { + level: sql.LevelReadCommitted, + isolation: "read committed", + }, + { + level: sql.LevelRepeatableRead, + isolation: "repeatable read", + }, + { + level: sql.LevelSerializable, + isolation: "serializable", + }, + } + + for _, test := range tests { + for _, ro := range []bool{true, false} { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: test.level, + ReadOnly: ro, + }) + if err != nil { + t.Fatal(err) + } + + var isolation string + err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation) + if err != nil { + t.Fatal(err) + } + + if test.isolation != "" && isolation != test.isolation { + t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation) + } + + var isRO string + err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO) + if err != nil { + t.Fatal(err) + } + + if ro != (isRO == "on") { + t.Errorf("read/[write,only] not set: %t != %s for level %s", + ro, isRO, test.isolation) + } + + tx.Rollback() + } + } + + _, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelLinearizable, + }) + if err == nil { + t.Fatal("expected LevelLinearizable to fail") + } + if !strings.Contains(err.Error(), "isolation level not supported") { + t.Errorf("Expected error to mention isolation level, got %q", err) + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/go19_test.go golang-github-lib-pq-1.3.0/go19_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/go19_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/go19_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,98 @@ +// +build go1.9 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "testing" +) + +func TestPing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + db := openTestConn(t) + defer db.Close() + + if _, ok := reflect.TypeOf(db).MethodByName("Conn"); !ok { + t.Skipf("Conn method undefined on type %T, skipping test (requires at least go1.9)", db) + } + + if err := db.PingContext(ctx); err != nil { + t.Fatal("expected Ping to succeed") + } + defer cancel() + + // grab a connection + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + // start a transaction and read backend pid of our connection + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ + 
Isolation: sql.LevelDefault, + ReadOnly: true, + }) + if err != nil { + t.Fatal(err) + } + + rows, err := tx.Query("SELECT pg_backend_pid()") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + // read the pid from result + var pid int + for rows.Next() { + if err := rows.Scan(&pid); err != nil { + t.Fatal(err) + } + } + if rows.Err() != nil { + t.Fatal(err) + } + // Fail the transaction and make sure we can still ping. + if _, err := tx.Query("INVALID SQL"); err == nil { + t.Fatal("expected error") + } + if err := conn.PingContext(ctx); err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + // kill the process which handles our connection and test if the ping fails + if _, err := db.Exec("SELECT pg_terminate_backend($1)", pid); err != nil { + t.Fatal(err) + } + if err := conn.PingContext(ctx); err != driver.ErrBadConn { + t.Fatalf("expected error %s, instead got %s", driver.ErrBadConn, err) + } +} + +func TestCommitInFailedTransactionWithCancelContext(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + txn, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + rows, err := txn.Query("SELECT error") + if err == nil { + rows.Close() + t.Fatal("expected failure") + } + err = txn.Commit() + if err != ErrInFailedTransaction { + t.Fatalf("expected ErrInFailedTransaction; got %#v", err) + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/go.mod golang-github-lib-pq-1.3.0/go.mod --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/go.mod 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/go.mod 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1 @@ +module github.com/lib/pq diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/hstore/hstore.go golang-github-lib-pq-1.3.0/hstore/hstore.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/hstore/hstore.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/hstore/hstore.go 2019-12-11 04:41:10.000000000 +0000 @@ -6,7 +6,7 @@ "strings" ) -// A wrapper for transferring Hstore values back and forth easily. +// Hstore is a wrapper for transferring Hstore values back and forth easily. 
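+// An editorial usage sketch for this type, from a caller's perspective
+// (assumes the hstore extension is installed; the query is illustrative,
+// not taken from this patch):
+//
+//	h := hstore.Hstore{Map: map[string]sql.NullString{
+//		"key1": {String: "value1", Valid: true},
+//	}}
+//	var out hstore.Hstore
+//	if err := db.QueryRow("SELECT $1::hstore", h).Scan(&out); err != nil {
+//		log.Fatal(err)
+//	}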
type Hstore struct { Map map[string]sql.NullString } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/hstore/hstore_test.go golang-github-lib-pq-1.3.0/hstore/hstore_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/hstore/hstore_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/hstore/hstore_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -36,7 +36,7 @@ db := openTestConn(t) defer db.Close() - // quitely create hstore if it doesn't exist + // quietly create hstore if it doesn't exist _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") if err != nil { t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) @@ -87,31 +87,31 @@ // a few example maps to test out hsOnePair := Hstore{ Map: map[string]sql.NullString{ - "key1": {"value1", true}, + "key1": {String: "value1", Valid: true}, }, } hsThreePairs := Hstore{ Map: map[string]sql.NullString{ - "key1": {"value1", true}, - "key2": {"value2", true}, - "key3": {"value3", true}, + "key1": {String: "value1", Valid: true}, + "key2": {String: "value2", Valid: true}, + "key3": {String: "value3", Valid: true}, }, } hsSmorgasbord := Hstore{ Map: map[string]sql.NullString{ - "nullstring": {"NULL", true}, - "actuallynull": {"", false}, - "NULL": {"NULL string key", true}, - "withbracket": {"value>42", true}, - "withequal": {"value=42", true}, - `"withquotes1"`: {`this "should" be fine`, true}, - `"withquotes"2"`: {`this "should\" also be fine`, true}, - "embedded1": {"value1=>x1", true}, - "embedded2": {`"value2"=>x2`, true}, - "withnewlines": {"\n\nvalue\t=>2", true}, - "<>": {`this, "should,\" also, => be fine`, true}, + "nullstring": {String: "NULL", Valid: true}, + "actuallynull": {String: "", Valid: false}, + "NULL": {String: "NULL string key", Valid: true}, + "withbracket": {String: "value>42", Valid: true}, + "withequal": {String: "value=42", Valid: true}, + `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, + `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, + "embedded1": {String: "value1=>x1", Valid: true}, + "embedded2": {String: `"value2"=>x2`, Valid: true}, + "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, + "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, }, } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/issues_test.go golang-github-lib-pq-1.3.0/issues_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/issues_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/issues_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,26 @@ +package pq + +import "testing" + +func TestIssue494(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)` + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + if _, err := txn.Prepare(CopyIn("t", "i")); err != nil { + t.Fatal(err) + } + + if _, err := txn.Query("SELECT 1"); err == nil { + t.Fatal("expected error") + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/listen_example/doc.go golang-github-lib-pq-1.3.0/listen_example/doc.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/listen_example/doc.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/listen_example/doc.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,102 +0,0 @@ -/* - -Below you will find a self-contained Go program which uses the LISTEN / NOTIFY -mechanism to avoid polling the database while waiting for more 
work to arrive. - - // - // You can see the program in action by defining a function similar to - // the following: - // - // CREATE OR REPLACE FUNCTION public.get_work() - // RETURNS bigint - // LANGUAGE sql - // AS $$ - // SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END - // $$ - // ; - - package main - - import ( - "database/sql" - "fmt" - "time" - - "github.com/lib/pq" - ) - - func doWork(db *sql.DB, work int64) { - // work here - } - - func getWork(db *sql.DB) { - for { - // get work from the database here - var work sql.NullInt64 - err := db.QueryRow("SELECT get_work()").Scan(&work) - if err != nil { - fmt.Println("call to get_work() failed: ", err) - time.Sleep(10 * time.Second) - continue - } - if !work.Valid { - // no more work to do - fmt.Println("ran out of work") - return - } - - fmt.Println("starting work on ", work.Int64) - go doWork(db, work.Int64) - } - } - - func waitForNotification(l *pq.Listener) { - for { - select { - case <-l.Notify: - fmt.Println("received notification, new work available") - return - case <-time.After(90 * time.Second): - go func() { - l.Ping() - }() - // Check if there's more work available, just in case it takes - // a while for the Listener to notice connection loss and - // reconnect. - fmt.Println("received no work for 90 seconds, checking for new work") - return - } - } - } - - func main() { - var conninfo string = "" - - db, err := sql.Open("postgres", conninfo) - if err != nil { - panic(err) - } - - reportProblem := func(ev pq.ListenerEventType, err error) { - if err != nil { - fmt.Println(err.Error()) - } - } - - listener := pq.NewListener(conninfo, 10 * time.Second, time.Minute, reportProblem) - err = listener.Listen("getwork") - if err != nil { - panic(err) - } - - fmt.Println("entering main loop") - for { - // process all available work before waiting for notifications - getWork(db) - waitForNotification(listener) - } - } - - -*/ -package listen_example diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/notify.go golang-github-lib-pq-1.3.0/notify.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/notify.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/notify.go 2019-12-11 04:41:10.000000000 +0000 @@ -60,16 +60,20 @@ replyChan chan message } -// Creates a new ListenerConn. Use NewListener instead. +// NewListenerConn creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { - cn, err := Open(name) + return newDialListenerConn(defaultDialer{}, name, notificationChan) +} + +func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { + cn, err := DialOpen(d, name) if err != nil { return nil, err } l := &ListenerConn{ cn: cn.(*conn), - notificationChan: notificationChan, + notificationChan: c, connState: connStateIdle, replyChan: make(chan message, 2), } @@ -210,17 +214,17 @@ // this ListenerConn is done } -// Send a LISTEN query to the server. See ExecSimpleQuery. +// Listen sends a LISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Listen(channel string) (bool, error) { return l.ExecSimpleQuery("LISTEN " + QuoteIdentifier(channel)) } -// Send an UNLISTEN query to the server. See ExecSimpleQuery. +// Unlisten sends an UNLISTEN query to the server. See ExecSimpleQuery. func (l *ListenerConn) Unlisten(channel string) (bool, error) { return l.ExecSimpleQuery("UNLISTEN " + QuoteIdentifier(channel)) } -// Send `UNLISTEN *` to the server. See ExecSimpleQuery. 
+// UnlistenAll sends an `UNLISTEN *` query to the server. See ExecSimpleQuery. func (l *ListenerConn) UnlistenAll() (bool, error) { return l.ExecSimpleQuery("UNLISTEN *") } @@ -263,8 +267,8 @@ return nil } -// Execute a "simple query" (i.e. one with no bindable parameters) on the -// connection. The possible return values are: +// ExecSimpleQuery executes a "simple query" (i.e. one with no bindable +// parameters) on the connection. The possible return values are: // 1) "executed" is true; the query was executed to completion on the // database server. If the query failed, err will be set to the error // returned by the database, otherwise err will be nil. @@ -329,6 +333,7 @@ } } +// Close closes the connection. func (l *ListenerConn) Close() error { l.connectionLock.Lock() if l.err != nil { @@ -342,7 +347,7 @@ return l.cn.c.Close() } -// Err() returns the reason the connection was closed. It is not safe to call +// Err returns the reason the connection was closed. It is not safe to call // this function until l.Notify has been closed. func (l *ListenerConn) Err() error { return l.err @@ -350,32 +355,43 @@ var errListenerClosed = errors.New("pq: Listener has been closed") +// ErrChannelAlreadyOpen is returned from Listen when a channel is already +// open. var ErrChannelAlreadyOpen = errors.New("pq: channel is already open") + +// ErrChannelNotOpen is returned from Unlisten when a channel is not open. var ErrChannelNotOpen = errors.New("pq: channel is not open") +// ListenerEventType is an enumeration of listener event types. type ListenerEventType int const ( - // Emitted only when the database connection has been initially - // initialized. err will always be nil. + // ListenerEventConnected is emitted only when the database connection + // has been initially initialized. The err argument of the callback + // will always be nil. ListenerEventConnected ListenerEventType = iota - // Emitted after a database connection has been lost, either because of an - // error or because Close has been called. err will be set to the reason - // the database connection was lost. + // ListenerEventDisconnected is emitted after a database connection has + // been lost, either because of an error or because Close has been + // called. The err argument will be set to the reason the database + // connection was lost. ListenerEventDisconnected - // Emitted after a database connection has been re-established after - // connection loss. err will always be nil. After this event has been - // emitted, a nil pq.Notification is sent on the Listener.Notify channel. + // ListenerEventReconnected is emitted after a database connection has + // been re-established after connection loss. The err argument of the + // callback will always be nil. After this event has been emitted, a + // nil pq.Notification is sent on the Listener.Notify channel. ListenerEventReconnected - // Emitted after a connection to the database was attempted, but failed. - // err will be set to an error describing why the connection attempt did - // not succeed. + // ListenerEventConnectionAttemptFailed is emitted after a connection + // to the database was attempted, but failed. The err argument will be + // set to an error describing why the connection attempt did not + // succeed. ListenerEventConnectionAttemptFailed ) +// EventCallbackType is the event callback type. See also ListenerEventType +// constants' documentation. 
type EventCallbackType func(event ListenerEventType, err error) // Listener provides an interface for listening to notifications from a @@ -391,6 +407,7 @@ name string minReconnectInterval time.Duration maxReconnectInterval time.Duration + dialer Dialer eventCallback EventCallbackType lock sync.Mutex @@ -421,10 +438,21 @@ minReconnectInterval time.Duration, maxReconnectInterval time.Duration, eventCallback EventCallbackType) *Listener { + return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) +} + +// NewDialListener is like NewListener but it takes a Dialer. +func NewDialListener(d Dialer, + name string, + minReconnectInterval time.Duration, + maxReconnectInterval time.Duration, + eventCallback EventCallbackType) *Listener { + l := &Listener{ name: name, minReconnectInterval: minReconnectInterval, maxReconnectInterval: maxReconnectInterval, + dialer: d, eventCallback: eventCallback, channels: make(map[string]struct{}), @@ -438,9 +466,9 @@ return l } -// Returns the notification channel for this listener. This is the same -// channel as Notify, and will not be recreated during the life time of the -// Listener. +// NotificationChannel returns the notification channel for this listener. +// This is the same channel as Notify, and will not be recreated during the +// life time of the Listener. func (l *Listener) NotificationChannel() <-chan *Notification { return l.Notify } @@ -609,7 +637,7 @@ // after the connection has been established. func (l *Listener) resync(cn *ListenerConn, notificationChan <-chan *Notification) error { doneChan := make(chan error) - go func() { + go func(notificationChan <-chan *Notification) { for channel := range l.channels { // If we got a response, return that error to our caller as it's // going to be more descriptive than cn.Err(). @@ -623,14 +651,14 @@ // close and then return the error message from the connection, as // per ListenerConn's interface. if err != nil { - for _ = range notificationChan { + for range notificationChan { } doneChan <- cn.Err() return } } doneChan <- nil - }() + }(notificationChan) // Ignore notifications while synchronization is going on to avoid // deadlocks. 
We have to send a nil notification over Notify anyway as @@ -660,7 +688,7 @@ func (l *Listener) connect() error { notificationChan := make(chan *Notification, 32) - cn, err := NewListenerConn(l.name, notificationChan) + cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) if err != nil { return err } @@ -697,6 +725,9 @@ } l.isClosed = true + // Unblock calls to Listen() + l.reconnectCond.Broadcast() + return nil } @@ -756,7 +787,7 @@ } l.emitEvent(ListenerEventDisconnected, err) - time.Sleep(nextReconnect.Sub(time.Now())) + time.Sleep(time.Until(nextReconnect)) } } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/notify_test.go golang-github-lib-pq-1.3.0/notify_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/notify_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/notify_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -7,7 +7,6 @@ "os" "runtime" "sync" - "sync/atomic" "testing" "time" ) @@ -124,6 +123,9 @@ } _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } err = expectNotification(t, channel, "notify_test", "") if err != nil { @@ -160,6 +162,9 @@ } _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } err = expectNotification(t, channel, "notify_test", "") if err != nil { @@ -235,15 +240,10 @@ // calls Close on the net.Conn; equivalent to a network failure l.Close() - var done int32 = 0 - go func() { - time.Sleep(10 * time.Second) - if atomic.LoadInt32(&done) != 1 { - panic("timed out") - } - }() + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() wg.Wait() - atomic.StoreInt32(&done, 1) } // Test for ListenerConn being closed while a slow query is executing @@ -271,15 +271,11 @@ if err != nil { t.Fatal(err) } - var done int32 = 0 - go func() { - time.Sleep(10 * time.Second) - if atomic.LoadInt32(&done) != 1 { - panic("timed out") - } - }() + + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() wg.Wait() - atomic.StoreInt32(&done, 1) } func TestNotifyExtra(t *testing.T) { diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/oid/gen.go golang-github-lib-pq-1.3.0/oid/gen.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/oid/gen.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/oid/gen.go 2019-12-11 04:41:10.000000000 +0000 @@ -10,10 +10,22 @@ "log" "os" "os/exec" + "strings" _ "github.com/lib/pq" ) +// OID represent a postgres Object Identifier Type. +type OID struct { + ID int + Type string +} + +// Name returns an upper case version of the oid type. +func (o OID) Name() string { + return strings.ToUpper(o.Type) +} + func main() { datname := os.Getenv("PGDATABASE") sslmode := os.Getenv("PGSSLMODE") @@ -30,6 +42,25 @@ if err != nil { log.Fatal(err) } + rows, err := db.Query(` + SELECT typname, oid + FROM pg_type WHERE oid < 10000 + ORDER BY oid; + `) + if err != nil { + log.Fatal(err) + } + oids := make([]*OID, 0) + for rows.Next() { + var oid OID + if err = rows.Scan(&oid.Type, &oid.ID); err != nil { + log.Fatal(err) + } + oids = append(oids, &oid) + } + if err = rows.Err(); err != nil { + log.Fatal(err) + } cmd := exec.Command("gofmt") cmd.Stderr = os.Stderr w, err := cmd.StdinPipe() @@ -45,30 +76,18 @@ if err != nil { log.Fatal(err) } - fmt.Fprintln(w, "// generated by 'go run gen.go'; do not edit") + fmt.Fprintln(w, "// Code generated by gen.go. 
DO NOT EDIT.") fmt.Fprintln(w, "\npackage oid") fmt.Fprintln(w, "const (") - rows, err := db.Query(` - SELECT typname, oid - FROM pg_type WHERE oid < 10000 - ORDER BY oid; - `) - if err != nil { - log.Fatal(err) - } - var name string - var oid int - for rows.Next() { - err = rows.Scan(&name, &oid) - if err != nil { - log.Fatal(err) - } - fmt.Fprintf(w, "T_%s Oid = %d\n", name, oid) - } - if err = rows.Err(); err != nil { - log.Fatal(err) + for _, oid := range oids { + fmt.Fprintf(w, "T_%s Oid = %d\n", oid.Type, oid.ID) } fmt.Fprintln(w, ")") + fmt.Fprintln(w, "var TypeName = map[Oid]string{") + for _, oid := range oids { + fmt.Fprintf(w, "T_%s: \"%s\",\n", oid.Type, oid.Name()) + } + fmt.Fprintln(w, "}") w.Close() cmd.Wait() } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/oid/types.go golang-github-lib-pq-1.3.0/oid/types.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/oid/types.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/oid/types.go 2019-12-11 04:41:10.000000000 +0000 @@ -1,4 +1,4 @@ -// generated by 'go run gen.go'; do not edit +// Code generated by gen.go. DO NOT EDIT. package oid @@ -18,6 +18,7 @@ T_xid Oid = 28 T_cid Oid = 29 T_oidvector Oid = 30 + T_pg_ddl_command Oid = 32 T_pg_type Oid = 71 T_pg_attribute Oid = 75 T_pg_proc Oid = 81 @@ -28,6 +29,7 @@ T_pg_node_tree Oid = 194 T__json Oid = 199 T_smgr Oid = 210 + T_index_am_handler Oid = 325 T_point Oid = 600 T_lseg Oid = 601 T_path Oid = 602 @@ -133,6 +135,9 @@ T__uuid Oid = 2951 T_txid_snapshot Oid = 2970 T_fdw_handler Oid = 3115 + T_pg_lsn Oid = 3220 + T__pg_lsn Oid = 3221 + T_tsm_handler Oid = 3310 T_anyenum Oid = 3500 T_tsvector Oid = 3614 T_tsquery Oid = 3615 @@ -144,6 +149,8 @@ T__regconfig Oid = 3735 T_regdictionary Oid = 3769 T__regdictionary Oid = 3770 + T_jsonb Oid = 3802 + T__jsonb Oid = 3807 T_anyrange Oid = 3831 T_event_trigger Oid = 3838 T_int4range Oid = 3904 @@ -158,4 +165,179 @@ T__daterange Oid = 3913 T_int8range Oid = 3926 T__int8range Oid = 3927 + T_pg_shseclabel Oid = 4066 + T_regnamespace Oid = 4089 + T__regnamespace Oid = 4090 + T_regrole Oid = 4096 + T__regrole Oid = 4097 ) + +var TypeName = map[Oid]string{ + T_bool: "BOOL", + T_bytea: "BYTEA", + T_char: "CHAR", + T_name: "NAME", + T_int8: "INT8", + T_int2: "INT2", + T_int2vector: "INT2VECTOR", + T_int4: "INT4", + T_regproc: "REGPROC", + T_text: "TEXT", + T_oid: "OID", + T_tid: "TID", + T_xid: "XID", + T_cid: "CID", + T_oidvector: "OIDVECTOR", + T_pg_ddl_command: "PG_DDL_COMMAND", + T_pg_type: "PG_TYPE", + T_pg_attribute: "PG_ATTRIBUTE", + T_pg_proc: "PG_PROC", + T_pg_class: "PG_CLASS", + T_json: "JSON", + T_xml: "XML", + T__xml: "_XML", + T_pg_node_tree: "PG_NODE_TREE", + T__json: "_JSON", + T_smgr: "SMGR", + T_index_am_handler: "INDEX_AM_HANDLER", + T_point: "POINT", + T_lseg: "LSEG", + T_path: "PATH", + T_box: "BOX", + T_polygon: "POLYGON", + T_line: "LINE", + T__line: "_LINE", + T_cidr: "CIDR", + T__cidr: "_CIDR", + T_float4: "FLOAT4", + T_float8: "FLOAT8", + T_abstime: "ABSTIME", + T_reltime: "RELTIME", + T_tinterval: "TINTERVAL", + T_unknown: "UNKNOWN", + T_circle: "CIRCLE", + T__circle: "_CIRCLE", + T_money: "MONEY", + T__money: "_MONEY", + T_macaddr: "MACADDR", + T_inet: "INET", + T__bool: "_BOOL", + T__bytea: "_BYTEA", + T__char: "_CHAR", + T__name: "_NAME", + T__int2: "_INT2", + T__int2vector: "_INT2VECTOR", + T__int4: "_INT4", + T__regproc: "_REGPROC", + T__text: "_TEXT", + T__tid: "_TID", + T__xid: "_XID", + T__cid: "_CID", + T__oidvector: "_OIDVECTOR", + T__bpchar: "_BPCHAR", + T__varchar: "_VARCHAR", + 
T__int8: "_INT8", + T__point: "_POINT", + T__lseg: "_LSEG", + T__path: "_PATH", + T__box: "_BOX", + T__float4: "_FLOAT4", + T__float8: "_FLOAT8", + T__abstime: "_ABSTIME", + T__reltime: "_RELTIME", + T__tinterval: "_TINTERVAL", + T__polygon: "_POLYGON", + T__oid: "_OID", + T_aclitem: "ACLITEM", + T__aclitem: "_ACLITEM", + T__macaddr: "_MACADDR", + T__inet: "_INET", + T_bpchar: "BPCHAR", + T_varchar: "VARCHAR", + T_date: "DATE", + T_time: "TIME", + T_timestamp: "TIMESTAMP", + T__timestamp: "_TIMESTAMP", + T__date: "_DATE", + T__time: "_TIME", + T_timestamptz: "TIMESTAMPTZ", + T__timestamptz: "_TIMESTAMPTZ", + T_interval: "INTERVAL", + T__interval: "_INTERVAL", + T__numeric: "_NUMERIC", + T_pg_database: "PG_DATABASE", + T__cstring: "_CSTRING", + T_timetz: "TIMETZ", + T__timetz: "_TIMETZ", + T_bit: "BIT", + T__bit: "_BIT", + T_varbit: "VARBIT", + T__varbit: "_VARBIT", + T_numeric: "NUMERIC", + T_refcursor: "REFCURSOR", + T__refcursor: "_REFCURSOR", + T_regprocedure: "REGPROCEDURE", + T_regoper: "REGOPER", + T_regoperator: "REGOPERATOR", + T_regclass: "REGCLASS", + T_regtype: "REGTYPE", + T__regprocedure: "_REGPROCEDURE", + T__regoper: "_REGOPER", + T__regoperator: "_REGOPERATOR", + T__regclass: "_REGCLASS", + T__regtype: "_REGTYPE", + T_record: "RECORD", + T_cstring: "CSTRING", + T_any: "ANY", + T_anyarray: "ANYARRAY", + T_void: "VOID", + T_trigger: "TRIGGER", + T_language_handler: "LANGUAGE_HANDLER", + T_internal: "INTERNAL", + T_opaque: "OPAQUE", + T_anyelement: "ANYELEMENT", + T__record: "_RECORD", + T_anynonarray: "ANYNONARRAY", + T_pg_authid: "PG_AUTHID", + T_pg_auth_members: "PG_AUTH_MEMBERS", + T__txid_snapshot: "_TXID_SNAPSHOT", + T_uuid: "UUID", + T__uuid: "_UUID", + T_txid_snapshot: "TXID_SNAPSHOT", + T_fdw_handler: "FDW_HANDLER", + T_pg_lsn: "PG_LSN", + T__pg_lsn: "_PG_LSN", + T_tsm_handler: "TSM_HANDLER", + T_anyenum: "ANYENUM", + T_tsvector: "TSVECTOR", + T_tsquery: "TSQUERY", + T_gtsvector: "GTSVECTOR", + T__tsvector: "_TSVECTOR", + T__gtsvector: "_GTSVECTOR", + T__tsquery: "_TSQUERY", + T_regconfig: "REGCONFIG", + T__regconfig: "_REGCONFIG", + T_regdictionary: "REGDICTIONARY", + T__regdictionary: "_REGDICTIONARY", + T_jsonb: "JSONB", + T__jsonb: "_JSONB", + T_anyrange: "ANYRANGE", + T_event_trigger: "EVENT_TRIGGER", + T_int4range: "INT4RANGE", + T__int4range: "_INT4RANGE", + T_numrange: "NUMRANGE", + T__numrange: "_NUMRANGE", + T_tsrange: "TSRANGE", + T__tsrange: "_TSRANGE", + T_tstzrange: "TSTZRANGE", + T__tstzrange: "_TSTZRANGE", + T_daterange: "DATERANGE", + T__daterange: "_DATERANGE", + T_int8range: "INT8RANGE", + T__int8range: "_INT8RANGE", + T_pg_shseclabel: "PG_SHSECLABEL", + T_regnamespace: "REGNAMESPACE", + T__regnamespace: "_REGNAMESPACE", + T_regrole: "REGROLE", + T__regrole: "_REGROLE", +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/README.md golang-github-lib-pq-1.3.0/README.md --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/README.md 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/README.md 2019-12-11 04:41:10.000000000 +0000 @@ -1,6 +1,7 @@ # pq - A pure Go postgres driver for Go's database/sql package -[![Build Status](https://travis-ci.org/lib/pq.png?branch=master)](https://travis-ci.org/lib/pq) +[![GoDoc](https://godoc.org/github.com/lib/pq?status.svg)](https://godoc.org/github.com/lib/pq) +[![Build Status](https://travis-ci.org/lib/pq.svg?branch=master)](https://travis-ci.org/lib/pq) ## Install @@ -9,22 +10,11 @@ ## Docs For detailed documentation and basic usage examples, please see the package -documentation at . 
+documentation at . ## Tests -`go test` is used for testing. A running PostgreSQL server is -required, with the ability to log in. The default database to connect -to test with is "pqgotest," but it can be overridden using environment -variables. - -Example: - - PGHOST=/var/run/postgresql go test github.com/lib/pq - -Optionally, a benchmark suite can be run as part of the tests: - - PGHOST=/var/run/postgresql go test -bench . +`go test` is used for testing. See [TESTS.md](TESTS.md) for more details. ## Features @@ -38,6 +28,7 @@ * Many libpq compatible environment variables * Unix socket support * Notifications: `LISTEN`/`NOTIFY` +* pgpass support ## Future / Things you can help with @@ -67,6 +58,7 @@ * Everyone at The Go Team * Evan Shaw (edsrzf) * Ewan Chou (coocood) +* Fazal Majid (fazalmajid) * Federico Romero (federomero) * Fumin (fumin) * Gary Burd (garyburd) @@ -83,7 +75,7 @@ * Keith Rarick (kr) * Kir Shatrov (kirs) * Lann Martin (lann) -* Maciek Sakrejda (deafbybeheading) +* Maciek Sakrejda (uhoh-itsmaciek) * Marc Brinkmann (mbr) * Marko Tiikkaja (johto) * Matt Newberry (MattNewberry) diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/rows.go golang-github-lib-pq-1.3.0/rows.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/rows.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/rows.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,93 @@ +package pq + +import ( + "math" + "reflect" + "time" + + "github.com/lib/pq/oid" +) + +const headerSize = 4 + +type fieldDesc struct { + // The object ID of the data type. + OID oid.Oid + // The data type size (see pg_type.typlen). + // Note that negative values denote variable-width types. + Len int + // The type modifier (see pg_attribute.atttypmod). + // The meaning of the modifier is type-specific. + Mod int +} + +func (fd fieldDesc) Type() reflect.Type { + switch fd.OID { + case oid.T_int8: + return reflect.TypeOf(int64(0)) + case oid.T_int4: + return reflect.TypeOf(int32(0)) + case oid.T_int2: + return reflect.TypeOf(int16(0)) + case oid.T_varchar, oid.T_text: + return reflect.TypeOf("") + case oid.T_bool: + return reflect.TypeOf(false) + case oid.T_date, oid.T_time, oid.T_timetz, oid.T_timestamp, oid.T_timestamptz: + return reflect.TypeOf(time.Time{}) + case oid.T_bytea: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf(new(interface{})).Elem() + } +} + +func (fd fieldDesc) Name() string { + return oid.TypeName[fd.OID] +} + +func (fd fieldDesc) Length() (length int64, ok bool) { + switch fd.OID { + case oid.T_text, oid.T_bytea: + return math.MaxInt64, true + case oid.T_varchar, oid.T_bpchar: + return int64(fd.Mod - headerSize), true + default: + return 0, false + } +} + +func (fd fieldDesc) PrecisionScale() (precision, scale int64, ok bool) { + switch fd.OID { + case oid.T_numeric, oid.T__numeric: + mod := fd.Mod - headerSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (rs *rows) ColumnTypeScanType(index int) reflect.Type { + return rs.colTyps[index].Type() +} + +// ColumnTypeDatabaseTypeName return the database system type name. +func (rs *rows) ColumnTypeDatabaseTypeName(index int) string { + return rs.colTyps[index].Name() +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. 
+func (rs *rows) ColumnTypeLength(index int) (length int64, ok bool) { + return rs.colTyps[index].Length() +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (rs *rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + return rs.colTyps[index].PrecisionScale() +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/rows_test.go golang-github-lib-pq-1.3.0/rows_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/rows_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/rows_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,218 @@ +package pq + +import ( + "math" + "reflect" + "testing" + + "github.com/lib/pq/oid" +) + +func TestDataTypeName(t *testing.T) { + tts := []struct { + typ oid.Oid + name string + }{ + {oid.T_int8, "INT8"}, + {oid.T_int4, "INT4"}, + {oid.T_int2, "INT2"}, + {oid.T_varchar, "VARCHAR"}, + {oid.T_text, "TEXT"}, + {oid.T_bool, "BOOL"}, + {oid.T_numeric, "NUMERIC"}, + {oid.T_date, "DATE"}, + {oid.T_time, "TIME"}, + {oid.T_timetz, "TIMETZ"}, + {oid.T_timestamp, "TIMESTAMP"}, + {oid.T_timestamptz, "TIMESTAMPTZ"}, + {oid.T_bytea, "BYTEA"}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if name := dt.Name(); name != tt.name { + t.Errorf("(%d) got: %s want: %s", i, name, tt.name) + } + } +} + +func TestDataType(t *testing.T) { + tts := []struct { + typ oid.Oid + kind reflect.Kind + }{ + {oid.T_int8, reflect.Int64}, + {oid.T_int4, reflect.Int32}, + {oid.T_int2, reflect.Int16}, + {oid.T_varchar, reflect.String}, + {oid.T_text, reflect.String}, + {oid.T_bool, reflect.Bool}, + {oid.T_date, reflect.Struct}, + {oid.T_time, reflect.Struct}, + {oid.T_timetz, reflect.Struct}, + {oid.T_timestamp, reflect.Struct}, + {oid.T_timestamptz, reflect.Struct}, + {oid.T_bytea, reflect.Slice}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if kind := dt.Type().Kind(); kind != tt.kind { + t.Errorf("(%d) got: %s want: %s", i, kind, tt.kind) + } + } +} + +func TestDataTypeLength(t *testing.T) { + tts := []struct { + typ oid.Oid + len int + mod int + length int64 + ok bool + }{ + {oid.T_int4, 0, -1, 0, false}, + {oid.T_varchar, 65535, 9, 5, true}, + {oid.T_text, 65535, -1, math.MaxInt64, true}, + {oid.T_bytea, 65535, -1, math.MaxInt64, true}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Len: tt.len, Mod: tt.mod} + if l, k := dt.Length(); k != tt.ok || l != tt.length { + t.Errorf("(%d) got: %d, %t want: %d, %t", i, l, k, tt.length, tt.ok) + } + } +} + +func TestDataTypePrecisionScale(t *testing.T) { + tts := []struct { + typ oid.Oid + mod int + precision, scale int64 + ok bool + }{ + {oid.T_int4, -1, 0, 0, false}, + {oid.T_numeric, 589830, 9, 2, true}, + {oid.T_text, -1, 0, 0, false}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Mod: tt.mod} + p, s, k := dt.PrecisionScale() + if k != tt.ok { + t.Errorf("(%d) got: %t want: %t", i, k, tt.ok) + } + if p != tt.precision { + t.Errorf("(%d) wrong precision got: %d want: %d", i, p, tt.precision) + } + if s != tt.scale { + t.Errorf("(%d) wrong scale got: %d want: %d", i, s, tt.scale) + } + } +} + +func TestRowsColumnTypes(t *testing.T) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { + Len int64 + OK bool + } + DecimalSize struct { + Precision int64 + Scale int64 + OK bool + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT4", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: 
false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int32(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, + } + + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + if err != nil { + t.Fatal(err) + } + + columns, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + if len(columns) != 3 { + t.Errorf("expected 3 columns found %d", len(columns)) + } + + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } + } +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/scram/scram.go golang-github-lib-pq-1.3.0/scram/scram.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/scram/scram.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/scram/scram.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,264 @@ +// Copyright (c) 2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package scram implements a SCRAM-{SHA-1,etc} client per RFC5802. 
+// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-256, for example, use: +// +// client := scram.NewClient(sha256.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that occurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. 
+func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 16 + buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + 
mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl.go golang-github-lib-pq-1.3.0/ssl.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/ssl.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,175 @@ +package pq + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "net" + "os" + "os/user" + "path/filepath" +) + +// ssl generates a function to upgrade a net.Conn based on the "sslmode" and +// related settings. The function is nil when no upgrade should take place. +func ssl(o values) (func(net.Conn) (net.Conn, error), error) { + verifyCaOnly := false + tlsConf := tls.Config{} + switch mode := o["sslmode"]; mode { + // "require" is the default. + case "", "require": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + + // From http://www.postgresql.org/docs/current/static/libpq-ssl.html: + // + // Note: For backwards compatibility with earlier versions of + // PostgreSQL, if a root CA file exists, the behavior of + // sslmode=require will be the same as that of verify-ca, meaning the + // server certificate is validated against the CA. Relying on this + // behavior is discouraged, and applications that need certificate + // validation should always use verify-ca or verify-full. + if sslrootcert, ok := o["sslrootcert"]; ok { + if _, err := os.Stat(sslrootcert); err == nil { + verifyCaOnly = true + } else { + delete(o, "sslrootcert") + } + } + case "verify-ca": + // We must skip TLS's own verification since it requires full + // verification since Go 1.3. + tlsConf.InsecureSkipVerify = true + verifyCaOnly = true + case "verify-full": + tlsConf.ServerName = o["host"] + case "disable": + return nil, nil + default: + return nil, fmterrorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode) + } + + err := sslClientCertificates(&tlsConf, o) + if err != nil { + return nil, err + } + err = sslCertificateAuthority(&tlsConf, o) + if err != nil { + return nil, err + } + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but + // the default configuration of older versions has it enabled. Redshift + // also initiates renegotiations and cannot be reconfigured. 
+ tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient + + return func(conn net.Conn) (net.Conn, error) { + client := tls.Client(conn, &tlsConf) + if verifyCaOnly { + err := sslVerifyCertificateAuthority(client, &tlsConf) + if err != nil { + return nil, err + } + } + return client, nil + }, nil +} + +// sslClientCertificates adds the certificate specified in the "sslcert" and +// "sslkey" settings, or if they aren't set, from the .postgresql directory +// in the user's home directory. The configured files must exist and have +// the correct permissions. +func sslClientCertificates(tlsConf *tls.Config, o values) error { + // user.Current() might fail when cross-compiling. We have to ignore the + // error and continue without home directory defaults, since we wouldn't + // know from where to load them. + user, _ := user.Current() + + // In libpq, the client certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037 + sslcert := o["sslcert"] + if len(sslcert) == 0 && user != nil { + sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt") + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1045 + if len(sslcert) == 0 { + return nil + } + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1050:L1054 + if _, err := os.Stat(sslcert); os.IsNotExist(err) { + return nil + } else if err != nil { + return err + } + + // In libpq, the ssl key is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1123-L1222 + sslkey := o["sslkey"] + if len(sslkey) == 0 && user != nil { + sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key") + } + + if len(sslkey) > 0 { + if err := sslKeyPermissions(sslkey); err != nil { + return err + } + } + + cert, err := tls.LoadX509KeyPair(sslcert, sslkey) + if err != nil { + return err + } + + tlsConf.Certificates = []tls.Certificate{cert} + return nil +} + +// sslCertificateAuthority adds the RootCA specified in the "sslrootcert" setting. +func sslCertificateAuthority(tlsConf *tls.Config, o values) error { + // In libpq, the root certificate is only loaded if the setting is not blank. + // + // https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L950-L951 + if sslrootcert := o["sslrootcert"]; len(sslrootcert) > 0 { + tlsConf.RootCAs = x509.NewCertPool() + + cert, err := ioutil.ReadFile(sslrootcert) + if err != nil { + return err + } + + if !tlsConf.RootCAs.AppendCertsFromPEM(cert) { + return fmterrorf("couldn't parse pem in sslrootcert") + } + } + + return nil +} + +// sslVerifyCertificateAuthority carries out a TLS handshake to the server and +// verifies the presented certificate against the CA, i.e. the one specified in +// sslrootcert or the system CA if sslrootcert was not specified. 
+func sslVerifyCertificateAuthority(client *tls.Conn, tlsConf *tls.Config) error { + err := client.Handshake() + if err != nil { + return err + } + certs := client.ConnectionState().PeerCertificates + opts := x509.VerifyOptions{ + DNSName: client.ConnectionState().ServerName, + Intermediates: x509.NewCertPool(), + Roots: tlsConf.RootCAs, + } + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + _, err = certs[0].Verify(opts) + return err +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_permissions.go golang-github-lib-pq-1.3.0/ssl_permissions.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_permissions.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/ssl_permissions.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,20 @@ +// +build !windows + +package pq + +import "os" + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(sslkey string) error { + info, err := os.Stat(sslkey) + if err != nil { + return err + } + if info.Mode().Perm()&0077 != 0 { + return ErrSSLKeyHasWorldPermissions + } + return nil +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_test.go golang-github-lib-pq-1.3.0/ssl_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/ssl_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -6,7 +6,6 @@ _ "crypto/sha256" "crypto/x509" "database/sql" - "fmt" "os" "path/filepath" "testing" @@ -42,10 +41,13 @@ } func checkSSLSetup(t *testing.T, conninfo string) { - db, err := openSSLConn(t, conninfo) - if err == nil { - db.Close() - t.Fatalf("expected error with conninfo=%q", conninfo) + _, err := openSSLConn(t, conninfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) } } @@ -100,127 +102,178 @@ } } -// Test sslmode=verify-ca -func TestSSLVerifyCA(t *testing.T) { +// Test sslmode=require sslrootcert=rootCertPath +func TestSSLRequireWithRootCert(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Not OK according to the system CA - _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") + bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt") + bogusRootCert := "sslrootcert=" + bogusRootCertPath + " " + + // Not OK according to the bogus CA + _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest") if err == nil { t.Fatal("expected error") } _, ok := err.(x509.UnknownAuthorityError) if !ok { - t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) + } + + nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt") + nonExistentCert := "sslrootcert=" + nonExistentCertPath + " " + + // No match on Common Name, but that's OK because we're not validating anything. 
+ _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest") + if err != nil { + t.Fatal(err) } rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") rootCert := "sslrootcert=" + rootCertPath + " " - // No match on Common Name, but that's OK - _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest") + + // No match on Common Name, but that's OK because we're not validating the CN. + _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest") if err != nil { t.Fatal(err) } // Everything OK - _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest") + _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest") if err != nil { t.Fatal(err) } } -func getCertConninfo(t *testing.T, source string) string { - var sslkey string - var sslcert string - - certpath := os.Getenv("PQSSLCERTTEST_PATH") - - switch source { - case "missingkey": - sslkey = "/tmp/filedoesnotexist" - sslcert = filepath.Join(certpath, "postgresql.crt") - case "missingcert": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = "/tmp/filedoesnotexist" - case "certtwice": - sslkey = filepath.Join(certpath, "postgresql.crt") - sslcert = filepath.Join(certpath, "postgresql.crt") - case "valid": - sslkey = filepath.Join(certpath, "postgresql.key") - sslcert = filepath.Join(certpath, "postgresql.crt") - default: - t.Fatalf("invalid source %q", source) - } - return fmt.Sprintf("sslmode=require user=pqgosslcert sslkey=%s sslcert=%s", sslkey, sslcert) -} - -// Authenticate over SSL using client certificates -func TestSSLClientCertificates(t *testing.T) { +// Test sslmode=verify-ca +func TestSSLVerifyCA(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Should also fail without a valid certificate - db, err := openSSLConn(t, "sslmode=require user=pqgosslcert") - if err == nil { - db.Close() - t.Fatal("expected error") - } - pge, ok := err.(*Error) - if !ok { - t.Fatal("expected pq.Error") + // Not OK according to the system CA + { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } } - if pge.Code.Name() != "invalid_authorization_specification" { - t.Fatalf("unexpected error code %q", pge.Code.Name()) + + // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. 
+ { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } } - // Should work - db, err = openSSLConn(t, getCertConninfo(t, "valid")) - if err != nil { + rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") + rootCert := "sslrootcert=" + rootCertPath + " " + // No match on Common Name, but that's OK + if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil { t.Fatal(err) } - rows, err := db.Query("SELECT 1") - if err != nil { + // Everything OK + if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil { t.Fatal(err) } - rows.Close() } -// Test errors with ssl certificates -func TestSSLClientCertificatesMissingFiles(t *testing.T) { +// Authenticate over SSL using client certificates +func TestSSLClientCertificates(t *testing.T) { maybeSkipSSLTests(t) // Environment sanity check: should fail without SSL checkSSLSetup(t, "sslmode=disable user=pqgossltest") - // Key missing, should fail - _, err := openSSLConn(t, getCertConninfo(t, "missingkey")) - if err == nil { - t.Fatal("expected error") + const baseinfo = "sslmode=require user=pqgosslcert" + + // Certificate not specified, should fail + { + _, err := openSSLConn(t, baseinfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + // Empty certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=''") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + // Non-existent certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } } - // should be a PathError - _, ok := err.(*os.PathError) + + certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH") if !ok { - t.Fatalf("expected PathError, got %#+v", err) + t.Fatalf("PQSSLCERTTEST_PATH not present in environment") } - // Cert missing, should fail - _, err = openSSLConn(t, getCertConninfo(t, "missingcert")) - if err == nil { - t.Fatal("expected error") + sslcert := filepath.Join(certpath, "postgresql.crt") + + // Cert present, key not specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert) + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } } - // should be a PathError - _, ok = err.(*os.PathError) - if !ok { - t.Fatalf("expected PathError, got %#+v", err) + + // Cert present, empty key specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } } - // Key has wrong permissions, should fail - _, err = openSSLConn(t, getCertConninfo(t, "certtwice")) - if err == nil { - t.Fatal("expected error") + // Cert present, non-existent key, should 
fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Key has wrong permissions (passing the cert as the key), should fail + if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions { + t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err) } - if err != ErrSSLKeyHasWorldPermissions { - t.Fatalf("expected ErrSSLKeyHasWorldPermissions, got %#+v", err) + + sslkey := filepath.Join(certpath, "postgresql.key") + + // Should work + if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil { + t.Fatal(err) + } else { + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } } } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_windows.go golang-github-lib-pq-1.3.0/ssl_windows.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/ssl_windows.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/ssl_windows.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,9 @@ +// +build windows + +package pq + +// sslKeyPermissions checks the permissions on user-supplied ssl key files. +// The key file should have very little access. +// +// libpq does not check key file permissions on Windows. +func sslKeyPermissions(string) error { return nil } diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/TESTS.md golang-github-lib-pq-1.3.0/TESTS.md --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/TESTS.md 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/TESTS.md 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,33 @@ +# Tests + +## Running Tests + +`go test` is used for testing. A running PostgreSQL +server is required, with the ability to log in. The +database to connect to test with is "pqgotest," on +"localhost" but these can be overridden using [environment +variables](https://www.postgresql.org/docs/9.3/static/libpq-envars.html). + +Example: + + PGHOST=/run/postgresql go test + +## Benchmarks + +A benchmark suite can be run as part of the tests: + + go test -bench . 
+ +## Example setup (Docker) + +Run a postgres container: + +``` +docker run --expose 5432:5432 postgres +``` + +Run tests: + +``` +PGHOST=localhost PGPORT=5432 PGUSER=postgres PGSSLMODE=disable PGDATABASE=postgres go test +``` diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/.travis.sh golang-github-lib-pq-1.3.0/.travis.sh --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/.travis.sh 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/.travis.sh 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,73 @@ +#!/bin/bash + +set -eu + +client_configure() { + sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key +} + +pgdg_repository() { + local sourcelist='sources.list.d/postgresql.list' + + curl -sS 'https://www.postgresql.org/media/keys/ACCC4CF8.asc' | sudo apt-key add - + echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION | sudo tee "/etc/apt/$sourcelist" + sudo apt-get -o Dir::Etc::sourcelist="$sourcelist" -o Dir::Etc::sourceparts='-' -o APT::Get::List-Cleanup='0' update +} + +postgresql_configure() { + sudo tee /etc/postgresql/$PGVERSION/main/pg_hba.conf > /dev/null <<-config + local all all trust + hostnossl all pqgossltest 127.0.0.1/32 reject + hostnossl all pqgosslcert 127.0.0.1/32 reject + hostssl all pqgossltest 127.0.0.1/32 trust + hostssl all pqgosslcert 127.0.0.1/32 cert + host all all 127.0.0.1/32 trust + hostnossl all pqgossltest ::1/128 reject + hostnossl all pqgosslcert ::1/128 reject + hostssl all pqgossltest ::1/128 trust + hostssl all pqgosslcert ::1/128 cert + host all all ::1/128 trust + config + + xargs sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ <<-certificates + certs/root.crt + certs/server.crt + certs/server.key + certificates + + sort -VCu <<-versions || + $PGVERSION + 9.2 + versions + sudo tee -a /etc/postgresql/$PGVERSION/main/postgresql.conf > /dev/null <<-config + ssl_ca_file = 'root.crt' + ssl_cert_file = 'server.crt' + ssl_key_file = 'server.key' + config + + echo 127.0.0.1 postgres | sudo tee -a /etc/hosts > /dev/null + + sudo service postgresql restart +} + +postgresql_install() { + xargs sudo apt-get -y -o Dpkg::Options::='--force-confdef' -o Dpkg::Options::='--force-confnew' install <<-packages + postgresql-$PGVERSION + postgresql-server-dev-$PGVERSION + postgresql-contrib-$PGVERSION + packages +} + +postgresql_uninstall() { + sudo service postgresql stop + xargs sudo apt-get -y --purge remove <<-packages + libpq-dev + libpq5 + postgresql + postgresql-client-common + postgresql-common + packages + sudo rm -rf /var/lib/postgresql +} + +$1 diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/.travis.yml golang-github-lib-pq-1.3.0/.travis.yml --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/.travis.yml 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/.travis.yml 2019-12-11 04:41:10.000000000 +0000 @@ -1,44 +1,11 @@ language: go go: - - 1.1 - - 1.2 - - 1.3 - - 1.4 - - 1.5 - - tip + - 1.11.x + - 1.12.x + - master -before_install: - - psql --version - - sudo /etc/init.d/postgresql stop - - sudo apt-get -y --purge remove postgresql libpq-dev libpq5 postgresql-client-common postgresql-common - - sudo rm -rf /var/lib/postgresql - - wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | sudo apt-key add - - - sudo sh -c "echo deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main $PGVERSION >> /etc/apt/sources.list.d/postgresql.list" - - sudo apt-get update -qq - - sudo apt-get -y -o Dpkg::Options::=--force-confdef -o 
Dpkg::Options::="--force-confnew" install postgresql-$PGVERSION postgresql-server-dev-$PGVERSION postgresql-contrib-$PGVERSION - - sudo chmod 777 /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all postgres trust" > /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "local all all trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert 127.0.0.1/32 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert 127.0.0.1/32 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all 127.0.0.1/32 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgossltest ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostnossl all pqgosslcert ::1/128 reject" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgossltest ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "hostssl all pqgosslcert ::1/128 cert" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - echo "host all all ::1/128 trust" >> /etc/postgresql/$PGVERSION/main/pg_hba.conf - - sudo install -o postgres -g postgres -m 600 -t /var/lib/postgresql/$PGVERSION/main/ certs/server.key certs/server.crt certs/root.crt - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_cert_file = 'server.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_key_file = 'server.key'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo bash -c "[[ '${PGVERSION}' < '9.2' ]] || (echo \"ssl_ca_file = 'root.crt'\" >> /etc/postgresql/$PGVERSION/main/postgresql.conf)" - - sudo sh -c "echo 127.0.0.1 postgres >> /etc/hosts" - - sudo ls -l /var/lib/postgresql/$PGVERSION/main/ - - sudo cat /etc/postgresql/$PGVERSION/main/postgresql.conf - - sudo chmod 600 $PQSSLCERTTEST_PATH/postgresql.key - - sudo /etc/init.d/postgresql restart +sudo: true env: global: @@ -47,23 +14,31 @@ - PQSSLCERTTEST_PATH=$PWD/certs - PGHOST=127.0.0.1 matrix: - - PGVERSION=9.4 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=9.3 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=9.2 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=9.1 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=9.0 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=8.4 PQTEST_BINARY_PARAMETERS=yes - - PGVERSION=9.4 PQTEST_BINARY_PARAMETERS=no - - PGVERSION=9.3 PQTEST_BINARY_PARAMETERS=no - - PGVERSION=9.2 PQTEST_BINARY_PARAMETERS=no - - PGVERSION=9.1 PQTEST_BINARY_PARAMETERS=no - - PGVERSION=9.0 PQTEST_BINARY_PARAMETERS=no - - PGVERSION=8.4 PQTEST_BINARY_PARAMETERS=no + - PGVERSION=10 + - PGVERSION=9.6 + - PGVERSION=9.5 + - PGVERSION=9.4 -script: - - go test -v ./... 
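A note for anyone reproducing the test setup above locally rather than in CI: `--expose` only records the port on the container and does not publish it to the host, so with `PGHOST=localhost` you generally want `docker run -p 5432:5432 postgres` instead, and newer tags of the postgres image also refuse to start without `POSTGRES_PASSWORD` (or `POSTGRES_HOST_AUTH_METHOD=trust`). The sketch below is a hypothetical smoke test, not part of this patch; it assumes lib/pq picks up the same `PG*` environment variables (PGHOST, PGPORT, PGUSER, PGSSLMODE, PGDATABASE) when the DSN is left empty.

```
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq"
)

func main() {
	// Empty DSN: lib/pq falls back to the libpq-style PG* environment
	// variables exported in the example above.
	db, err := sql.Open("postgres", "")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	var version string
	if err := db.QueryRow("SELECT version()").Scan(&version); err != nil {
		log.Fatal(err)
	}
	fmt.Println("connected to:", version)
}
```

If the container is up and the variables are exported as in the example, `go run` on this file should print the server version string.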
+before_install: + - ./.travis.sh postgresql_uninstall + - ./.travis.sh pgdg_repository + - ./.travis.sh postgresql_install + - ./.travis.sh postgresql_configure + - ./.travis.sh client_configure + - go get golang.org/x/tools/cmd/goimports + - go get golang.org/x/lint/golint + - GO111MODULE=on go get honnef.co/go/tools/cmd/staticcheck@2019.2.1 before_script: - - psql -c 'create database pqgotest' -U postgres - - psql -c 'create user pqgossltest' -U postgres - - psql -c 'create user pqgosslcert' -U postgres + - createdb pqgotest + - createuser -DRS pqgossltest + - createuser -DRS pqgosslcert + +script: + - > + goimports -d -e $(find -name '*.go') | awk '{ print } END { exit NR == 0 ? 0 : 1 }' + - go vet ./... + - staticcheck -go 1.11 ./... + - golint ./... + - PQTEST_BINARY_PARAMETERS=no go test -race -v ./... + - PQTEST_BINARY_PARAMETERS=yes go test -race -v ./... diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/url.go golang-github-lib-pq-1.3.0/url.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/url.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/url.go 2019-12-11 04:41:10.000000000 +0000 @@ -2,6 +2,7 @@ import ( "fmt" + "net" nurl "net/url" "sort" "strings" @@ -54,12 +55,11 @@ accrue("password", v) } - i := strings.Index(u.Host, ":") - if i < 0 { + if host, port, err := net.SplitHostPort(u.Host); err != nil { accrue("host", u.Host) } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) + accrue("host", host) + accrue("port", port) } if u.Path != "" { diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/url_test.go golang-github-lib-pq-1.3.0/url_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/url_test.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/url_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -16,6 +16,18 @@ } } +func TestIPv6LoopbackParseURL(t *testing.T) { + expected := "host=::1 port=1234" + str, err := ParseURL("postgres://[::1]:1234") + if err != nil { + t.Fatal(err) + } + + if str != expected { + t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) + } +} + func TestFullParseURL(t *testing.T) { expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/user_posix.go golang-github-lib-pq-1.3.0/user_posix.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/user_posix.go 2015-10-14 20:45:47.000000000 +0000 +++ golang-github-lib-pq-1.3.0/user_posix.go 2019-12-11 04:41:10.000000000 +0000 @@ -1,6 +1,6 @@ // Package pq is a pure Go Postgres driver for the database/sql package. -// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris +// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris rumprun package pq diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/uuid.go golang-github-lib-pq-1.3.0/uuid.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/uuid.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/uuid.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,23 @@ +package pq + +import ( + "encoding/hex" + "fmt" +) + +// decodeUUIDBinary interprets the binary format of a uuid, returning it in text format. 
+func decodeUUIDBinary(src []byte) ([]byte, error) { + if len(src) != 16 { + return nil, fmt.Errorf("pq: unable to decode uuid; bad length: %d", len(src)) + } + + dst := make([]byte, 36) + dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-' + hex.Encode(dst[0:], src[0:4]) + hex.Encode(dst[9:], src[4:6]) + hex.Encode(dst[14:], src[6:8]) + hex.Encode(dst[19:], src[8:10]) + hex.Encode(dst[24:], src[10:16]) + + return dst, nil +} diff -Nru golang-github-lib-pq-0.0~git20151007.0.ffe986a/uuid_test.go golang-github-lib-pq-1.3.0/uuid_test.go --- golang-github-lib-pq-0.0~git20151007.0.ffe986a/uuid_test.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-github-lib-pq-1.3.0/uuid_test.go 2019-12-11 04:41:10.000000000 +0000 @@ -0,0 +1,46 @@ +package pq + +import ( + "reflect" + "strings" + "testing" +) + +func TestDecodeUUIDBinaryError(t *testing.T) { + t.Parallel() + _, err := decodeUUIDBinary([]byte{0x12, 0x34}) + + if err == nil { + t.Fatal("Expected error, got none") + } + if !strings.HasPrefix(err.Error(), "pq:") { + t.Errorf("Expected error to start with %q, got %q", "pq:", err.Error()) + } + if !strings.Contains(err.Error(), "bad length: 2") { + t.Errorf("Expected error to contain length, got %q", err.Error()) + } +} + +func BenchmarkDecodeUUIDBinary(b *testing.B) { + x := []byte{0x03, 0xa3, 0x52, 0x2f, 0x89, 0x28, 0x49, 0x87, 0x84, 0xd6, 0x93, 0x7b, 0x36, 0xec, 0x27, 0x6f} + + for i := 0; i < b.N; i++ { + decodeUUIDBinary(x) + } +} + +func TestDecodeUUIDBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a" + var scanned interface{} + + err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(scanned, []byte(s)) { + t.Errorf("Expected []byte(%q), got %T(%q)", s, scanned, scanned) + } +}
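The uuid_test.go additions above exercise the error path, a benchmark, and a round trip through the backend. As a standalone illustration of the layout decodeUUIDBinary produces (the function is unexported, so this sketch simply mirrors its logic rather than calling the driver), feeding it the benchmark's sample bytes yields the canonical 8-4-4-4-12 text form:

```
package main

import (
	"encoding/hex"
	"fmt"
)

// textUUID mirrors decodeUUIDBinary above: 16 binary bytes become the
// 36-character text form, with dashes at offsets 8, 13, 18 and 23.
func textUUID(src []byte) (string, error) {
	if len(src) != 16 {
		return "", fmt.Errorf("bad length: %d", len(src))
	}
	dst := make([]byte, 36)
	dst[8], dst[13], dst[18], dst[23] = '-', '-', '-', '-'
	hex.Encode(dst[0:], src[0:4])
	hex.Encode(dst[9:], src[4:6])
	hex.Encode(dst[14:], src[6:8])
	hex.Encode(dst[19:], src[8:10])
	hex.Encode(dst[24:], src[10:16])
	return string(dst), nil
}

func main() {
	// Sample bytes from BenchmarkDecodeUUIDBinary.
	x := []byte{0x03, 0xa3, 0x52, 0x2f, 0x89, 0x28, 0x49, 0x87, 0x84, 0xd6, 0x93, 0x7b, 0x36, 0xec, 0x27, 0x6f}
	s, err := textUUID(x)
	if err != nil {
		panic(err)
	}
	fmt.Println(s) // 03a3522f-8928-4987-84d6-937b36ec276f
}
```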
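One more usage note, on the url.go change further up: replacing the hand-rolled colon split with net.SplitHostPort is what lets bracketed IPv6 literals through, which the new TestIPv6LoopbackParseURL covers. A hypothetical caller-side sketch, reusing the URL from that test:

```
package main

import (
	"database/sql"
	"fmt"
	"log"

	"github.com/lib/pq"
)

func main() {
	// Same URL as TestIPv6LoopbackParseURL; ParseURL converts it to the
	// keyword/value form "host=::1 port=1234".
	dsn, err := pq.ParseURL("postgres://[::1]:1234")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(dsn)

	// The result can be passed straight to sql.Open; importing lib/pq
	// also registers the "postgres" driver.
	db, err := sql.Open("postgres", dsn)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
}
```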