diff -Nru golang-blitiri-go-spf-1.1.0/cmd/spf-check/.gitignore golang-blitiri-go-spf-1.3.0/cmd/spf-check/.gitignore --- golang-blitiri-go-spf-1.1.0/cmd/spf-check/.gitignore 1970-01-01 00:00:00.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/cmd/spf-check/.gitignore 2021-11-20 17:24:26.000000000 +0000 @@ -0,0 +1 @@ +spf-check diff -Nru golang-blitiri-go-spf-1.1.0/cmd/spf-check/spf-check.go golang-blitiri-go-spf-1.3.0/cmd/spf-check/spf-check.go --- golang-blitiri-go-spf-1.1.0/cmd/spf-check/spf-check.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/cmd/spf-check/spf-check.go 2021-11-20 17:24:26.000000000 +0000 @@ -0,0 +1,64 @@ +// +build ignore + +// Command line tool to perform SPF checks. +// +// For development and experimentation only. +// No backwards compatibility guarantees. +package main + +import ( + "context" + "flag" + "fmt" + "net" + "os" + + "blitiri.com.ar/go/spf" +) + +var ( + debug = flag.Bool("debug", false, "include debugging output") + dnsAddr = flag.String("dns_addr", "", "address of the DNS server to use") +) + +func main() { + flag.Usage = func() { + fmt.Printf("Usage: spf-check [options] 1.2.3.4 name@sender.com\n\n") + flag.PrintDefaults() + } + + flag.Parse() + args := flag.Args() + if len(args) < 2 { + flag.Usage() + os.Exit(1) + } + + opts := []spf.Option{} + if *debug { + traceF := func(f string, a ...interface{}) { + fmt.Printf("debug: "+f+"\n", a...) + } + opts = append(opts, spf.WithTraceFunc(traceF)) + } + + if *dnsAddr != "" { + dialFunc := func(ctx context.Context, network, addr string) (net.Conn, error) { + return (&net.Dialer{}).DialContext(ctx, network, *dnsAddr) + } + opts = append(opts, spf.WithResolver( + &net.Resolver{ + PreferGo: true, + Dial: dialFunc, + })) + } + + ip := net.ParseIP(args[0]) + sender := args[1] + fmt.Printf("Sender: %v\n", sender) + fmt.Printf("IP: %v\n", ip) + + r, err := spf.CheckHostWithSender(ip, "", sender, opts...) + fmt.Printf("Result: %v\n", r) + fmt.Printf("Error: %v\n", err) +} diff -Nru golang-blitiri-go-spf-1.1.0/debian/changelog golang-blitiri-go-spf-1.3.0/debian/changelog --- golang-blitiri-go-spf-1.1.0/debian/changelog 2020-06-28 10:56:36.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/debian/changelog 2021-11-20 17:41:27.000000000 +0000 @@ -1,3 +1,17 @@ +golang-blitiri-go-spf (1.3.0-1) unstable; urgency=medium + + * New upstream release (1.3.0) + * Standards-Version: 4.6.0 (no changes) + + -- Alberto Bertogli Sat, 20 Nov 2021 17:41:27 +0000 + +golang-blitiri-go-spf (1.2.0-1) unstable; urgency=medium + + * New upstream release (1.2.0) + * Standards-Version 4.5.1 (no changes) + + -- Alberto Bertogli Sun, 09 May 2021 20:00:13 +0100 + golang-blitiri-go-spf (1.1.0-1) unstable; urgency=medium * New upstream release (1.1.0) diff -Nru golang-blitiri-go-spf-1.1.0/debian/control golang-blitiri-go-spf-1.3.0/debian/control --- golang-blitiri-go-spf-1.1.0/debian/control 2020-06-28 10:56:36.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/debian/control 2021-11-20 17:41:27.000000000 +0000 @@ -7,7 +7,7 @@ dh-golang, golang-any, golang-gopkg-yaml.v2-dev, -Standards-Version: 4.5.0 +Standards-Version: 4.6.0 Rules-Requires-Root: no Homepage: https://blitiri.com.ar/git/r/spf/ Vcs-Browser: https://salsa.debian.org/go-team/packages/golang-blitiri-go-spf @@ -19,6 +19,7 @@ Architecture: all Depends: ${misc:Depends}, ${shlibs:Depends}, +Multi-Arch: foreign Description: SPF (Sender Policy Framework) implementation in Go blitiri.com.ar/go/spf is an open source implementation of the Sender Policy Framework (SPF) in Go. diff -Nru golang-blitiri-go-spf-1.1.0/debian/rules golang-blitiri-go-spf-1.3.0/debian/rules --- golang-blitiri-go-spf-1.1.0/debian/rules 2020-06-28 10:56:36.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/debian/rules 2021-11-20 17:41:27.000000000 +0000 @@ -1,4 +1,6 @@ #!/usr/bin/make -f +export DH_GOLANG_EXCLUDES := testdata/fuzz + %: dh $@ --buildsystem=golang --with=golang diff -Nru golang-blitiri-go-spf-1.1.0/dns_test.go golang-blitiri-go-spf-1.3.0/dns_test.go --- golang-blitiri-go-spf-1.1.0/dns_test.go 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/dns_test.go 1970-01-01 00:00:00.000000000 +0000 @@ -1,69 +0,0 @@ -package spf - -import ( - "flag" - "net" - "os" - "strings" - "testing" -) - -// DNS overrides for testing. - -type DNS struct { - txt map[string][]string - mx map[string][]*net.MX - ip map[string][]net.IP - addr map[string][]string - errors map[string]error -} - -func NewDNS() DNS { - return DNS{ - txt: map[string][]string{}, - mx: map[string][]*net.MX{}, - ip: map[string][]net.IP{}, - addr: map[string][]string{}, - errors: map[string]error{}, - } -} - -// Single global variable that the overridden resolvers use. -// This way it's easier to get a clean slate between tests. -var dns DNS - -func LookupTXT(domain string) (txts []string, err error) { - domain = strings.ToLower(domain) - domain = strings.TrimRight(domain, ".") - return dns.txt[domain], dns.errors[domain] -} - -func LookupMX(domain string) (mxs []*net.MX, err error) { - domain = strings.ToLower(domain) - domain = strings.TrimRight(domain, ".") - return dns.mx[domain], dns.errors[domain] -} - -func LookupIP(host string) (ips []net.IP, err error) { - host = strings.ToLower(host) - host = strings.TrimRight(host, ".") - return dns.ip[host], dns.errors[host] -} - -func LookupAddr(host string) (addrs []string, err error) { - host = strings.ToLower(host) - host = strings.TrimRight(host, ".") - return dns.addr[host], dns.errors[host] -} - -func TestMain(m *testing.M) { - dns = NewDNS() - - lookupTXT = LookupTXT - lookupMX = LookupMX - lookupIP = LookupIP - lookupAddr = LookupAddr - - flag.Parse() - os.Exit(m.Run()) -} diff -Nru golang-blitiri-go-spf-1.1.0/fuzz.go golang-blitiri-go-spf-1.3.0/fuzz.go --- golang-blitiri-go-spf-1.1.0/fuzz.go 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/fuzz.go 2021-11-20 17:24:26.000000000 +0000 @@ -10,7 +10,11 @@ package spf -import "net" +import ( + "net" + + "blitiri.com.ar/go/spf/internal/dnstest" +) // Parsed IP addresses, for convenience. var ( @@ -20,51 +24,29 @@ ip6660 = net.ParseIP("2001:db8::0") ) -// Results for TXT lookups. This one is global as the values will be set by -// the fuzzer. The other lookup types are static and configured in init, see -// below). -var txtResults = map[string][]string{} +// DNS resolver to use. Will be initialized once with the expected fixtures, +// and then reused on each fuzz run. +var dns = dnstest.NewResolver() func init() { - // Make the resolving functions return our test data. - // The test data is fixed, the fuzzer doesn't change it. - // TODO: Once go-fuzz can run functions from _test.go files, move this to - // spf_test.go to avoid duplicating all this boilerplate. - var ( - mxResults = map[string][]*net.MX{} - ipResults = map[string][]net.IP{} - addrResults = map[string][]string{} - ) - - lookupTXT = func(domain string) (txts []string, err error) { - return txtResults[domain], nil - } - lookupMX = func(domain string) (mxs []*net.MX, err error) { - return mxResults[domain], nil - } - lookupIP = func(host string) (ips []net.IP, err error) { - return ipResults[host], nil - } - lookupAddr = func(host string) (addrs []string, err error) { - return addrResults[host], nil - } - - ipResults["d1111"] = []net.IP{ip1111} - ipResults["d1110"] = []net.IP{ip1110} - mxResults["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}} - ipResults["d6666"] = []net.IP{ip6666} - ipResults["d6660"] = []net.IP{ip6660} - mxResults["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}} - addrResults["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} - addrResults["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."} + dns.Ip["d1111"] = []net.IP{ip1111} + dns.Ip["d1110"] = []net.IP{ip1110} + dns.Mx["d1110"] = []*net.MX{{"d1110", 5}, {"nothing", 10}} + dns.Ip["d6666"] = []net.IP{ip6666} + dns.Ip["d6660"] = []net.IP{ip6660} + dns.Mx["d6660"] = []*net.MX{{"d6660", 5}, {"nothing", 10}} + dns.Addr["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} + dns.Addr["1.1.1.1"] = []string{"lalala.", "domain.", "d1111."} } func Fuzz(data []byte) int { // The domain's TXT record comes from the fuzzer. - txtResults["domain"] = []string{string(data)} + dns.Txt["domain"] = []string{string(data)} - v4result, _ := CheckHost(ip1111, "domain") // IPv4 - v6result, _ := CheckHost(ip6666, "domain") // IPv6 + v4result, _ := CheckHostWithSender( + ip1111, "helo", "domain", WithResolver(dns)) + v6result, _ := CheckHostWithSender( + ip6666, "helo", "domain", WithResolver(dns)) // Raise priority if any of the results was something other than // PermError, as it means the data was better formed. diff -Nru golang-blitiri-go-spf-1.1.0/.gitlab-ci.yml golang-blitiri-go-spf-1.3.0/.gitlab-ci.yml --- golang-blitiri-go-spf-1.1.0/.gitlab-ci.yml 1970-01-01 00:00:00.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/.gitlab-ci.yml 2021-11-20 17:24:26.000000000 +0000 @@ -0,0 +1,25 @@ +# Configuration for the GitLab CI. + +# Go tests, on various Go versions. +.golang_template: &golang + stage: test + script: + - go test ./... + - go test -race ./... + +golang_1.15: + <<: *golang + image: golang:1.15 # Oldest supported version (for now). + +golang_latest: + <<: *golang + image: golang:latest + +coverage: + <<: *golang + image: golang:latest + script: + - go test -covermode=count -coverprofile=coverage.out + - go get github.com/mattn/goveralls + - goveralls -coverprofile=coverage.out -service=gitlab -repotoken=$COVERALLS_TOKEN + diff -Nru golang-blitiri-go-spf-1.1.0/go.mod golang-blitiri-go-spf-1.3.0/go.mod --- golang-blitiri-go-spf-1.1.0/go.mod 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/go.mod 2021-11-20 17:24:26.000000000 +0000 @@ -1,5 +1,5 @@ module blitiri.com.ar/go/spf -go 1.14 +go 1.15 require gopkg.in/yaml.v2 v2.3.0 diff -Nru golang-blitiri-go-spf-1.1.0/go.sum golang-blitiri-go-spf-1.3.0/go.sum --- golang-blitiri-go-spf-1.1.0/go.sum 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/go.sum 2021-11-20 17:24:26.000000000 +0000 @@ -1,3 +1,4 @@ +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.3.0 h1:clyUAQHOM3G0M3f5vQj7LuJrETvjVot3Z5el9nffUtU= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff -Nru golang-blitiri-go-spf-1.1.0/internal/dnstest/dns.go golang-blitiri-go-spf-1.3.0/internal/dnstest/dns.go --- golang-blitiri-go-spf-1.1.0/internal/dnstest/dns.go 1970-01-01 00:00:00.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/internal/dnstest/dns.go 2021-11-20 17:24:26.000000000 +0000 @@ -0,0 +1,80 @@ +// DNS resolver for testing purposes. +// +// In the future, when go fuzz can make use of _test.go files, we can rename +// this file dns_test.go and remove this extra package entirely. +// Until then, unfortunately this is the most reasonable way to share these +// helpers between go and fuzz tests. +package dnstest + +import ( + "context" + "net" + "strings" +) + +// Testing DNS resolver. +// +// Not exported since this is not part of the public API and only used +// internally on tests. +// +type TestResolver struct { + Txt map[string][]string + Mx map[string][]*net.MX + Ip map[string][]net.IP + Addr map[string][]string + Errors map[string]error +} + +func NewResolver() *TestResolver { + return &TestResolver{ + Txt: map[string][]string{}, + Mx: map[string][]*net.MX{}, + Ip: map[string][]net.IP{}, + Addr: map[string][]string{}, + Errors: map[string]error{}, + } +} + +func (r *TestResolver) LookupTXT(ctx context.Context, domain string) (txts []string, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + domain = strings.ToLower(domain) + domain = strings.TrimRight(domain, ".") + return r.Txt[domain], r.Errors[domain] +} + +func (r *TestResolver) LookupMX(ctx context.Context, domain string) (mxs []*net.MX, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + domain = strings.ToLower(domain) + domain = strings.TrimRight(domain, ".") + return r.Mx[domain], r.Errors[domain] +} + +func (r *TestResolver) LookupIPAddr(ctx context.Context, host string) (as []net.IPAddr, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + host = strings.ToLower(host) + host = strings.TrimRight(host, ".") + return ipsToAddrs(r.Ip[host]), r.Errors[host] +} + +func ipsToAddrs(ips []net.IP) []net.IPAddr { + as := []net.IPAddr{} + for _, ip := range ips { + as = append(as, net.IPAddr{IP: ip, Zone: ""}) + } + return as +} + +func (r *TestResolver) LookupAddr(ctx context.Context, host string) (addrs []string, err error) { + if ctx.Err() != nil { + return nil, ctx.Err() + } + host = strings.ToLower(host) + host = strings.TrimRight(host, ".") + return r.Addr[host], r.Errors[host] +} diff -Nru golang-blitiri-go-spf-1.1.0/README.md golang-blitiri-go-spf-1.3.0/README.md --- golang-blitiri-go-spf-1.1.0/README.md 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/README.md 2021-11-20 17:24:26.000000000 +0000 @@ -1,15 +1,16 @@ # blitiri.com.ar/go/spf -[![GoDoc](https://godoc.org/blitiri.com.ar/go/spf?status.svg)](https://godoc.org/blitiri.com.ar/go/spf) -[![Build Status](https://travis-ci.org/albertito/spf.svg?branch=master)](https://travis-ci.org/albertito/spf) +[![GoDoc](https://godoc.org/blitiri.com.ar/go/spf?status.svg)](https://pkg.go.dev/blitiri.com.ar/go/spf) +[![Build Status](https://gitlab.com/albertito/spf/badges/master/pipeline.svg)](https://gitlab.com/albertito/spf/-/commits/master) [![Go Report Card](https://goreportcard.com/badge/github.com/albertito/spf)](https://goreportcard.com/report/github.com/albertito/spf) [![Coverage Status](https://coveralls.io/repos/github/albertito/spf/badge.svg?branch=next)](https://coveralls.io/github/albertito/spf) [spf](https://godoc.org/blitiri.com.ar/go/spf) is an open source implementation of the Sender Policy Framework (SPF) in Go. -It is used by the [chasquid](https://blitiri.com.ar/p/chasquid/) SMTP server. +It is used by the [chasquid](https://blitiri.com.ar/p/chasquid/) and +[maddy](https://maddy.email) SMTP servers. ## Example @@ -27,8 +28,8 @@ } ``` -See the [documentation](https://godoc.org/blitiri.com.ar/go/spf) for more -details. +See the [package documentation](https://pkg.go.dev/blitiri.com.ar/go/spf) for +more details. ## Status @@ -40,9 +41,9 @@ ## Contact If you have any questions, comments or patches please send them to the mailing -list, chasquid@googlegroups.com. +list, `chasquid@googlegroups.com`. -To subscribe, send an email to chasquid+subscribe@googlegroups.com. +To subscribe, send an email to `chasquid+subscribe@googlegroups.com`. You can also browse the [archives](https://groups.google.com/forum/#!forum/chasquid). diff -Nru golang-blitiri-go-spf-1.1.0/spf.go golang-blitiri-go-spf-1.3.0/spf.go --- golang-blitiri-go-spf-1.1.0/spf.go 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/spf.go 2021-11-20 17:24:26.000000000 +0000 @@ -5,18 +5,19 @@ // exchangers to check that incoming mail from a domain comes from a host // authorized by that domain's administrators [Wikipedia]. // -// This is a Go implementation of it, which is used by the chasquid SMTP -// server (https://blitiri.com.ar/p/chasquid/). +// This package is intended to be used by SMTP servers to implement SPF +// validation. // -// Supported mechanisms and modifiers: +// All mechanisms and modifiers are supported: // all // include // a // mx +// ptr // ip4 // ip6 -// redirect // exists +// redirect // exp (ignored) // Macros // @@ -26,6 +27,8 @@ package spf // import "blitiri.com.ar/go/spf" import ( + "context" + "errors" "fmt" "net" "net/url" @@ -34,15 +37,6 @@ "strings" ) -// Functions that we can override for testing purposes. -var ( - lookupTXT = net.LookupTXT - lookupMX = net.LookupMX - lookupIP = net.LookupIP - lookupAddr = net.LookupAddr - trace = func(f string, a ...interface{}) {} -) - // The Result of an SPF check. Note the values have meaning, we use them in // headers. https://tools.ietf.org/html/rfc7208#section-8 type Result string @@ -85,52 +79,171 @@ '?': Neutral, } +// Errors returned by the library. Note that the errors returned in different +// situations may change over time, and new ones may be added. Be careful +// about over-relying on these. var ( - errLookupLimitReached = fmt.Errorf("lookup limit reached") - errUnknownField = fmt.Errorf("unknown field") - errInvalidIP = fmt.Errorf("invalid ipX value") - errInvalidMask = fmt.Errorf("invalid mask") - errInvalidMacro = fmt.Errorf("invalid macro") - errInvalidDomain = fmt.Errorf("invalid domain") - errNoResult = fmt.Errorf("no DNS record found") - errMultipleRecords = fmt.Errorf("multiple matching DNS records") - errTooManyMXRecords = fmt.Errorf("too many MX records") - - errMatchedAll = fmt.Errorf("matched 'all'") - errMatchedA = fmt.Errorf("matched 'a'") - errMatchedIP = fmt.Errorf("matched 'ip'") - errMatchedMX = fmt.Errorf("matched 'mx'") - errMatchedPTR = fmt.Errorf("matched 'ptr'") - errMatchedExists = fmt.Errorf("matched 'exists'") + // Errors related to an invalid SPF record. + ErrUnknownField = errors.New("unknown field") + ErrInvalidIP = errors.New("invalid ipX value") + ErrInvalidMask = errors.New("invalid mask") + ErrInvalidMacro = errors.New("invalid macro") + ErrInvalidDomain = errors.New("invalid domain") + + // Errors related to DNS lookups. + // Note that the library functions may also return net.DNSError. + ErrNoResult = errors.New("no DNS record found") + ErrLookupLimitReached = errors.New("lookup limit reached") + ErrTooManyMXRecords = errors.New("too many MX records") + ErrMultipleRecords = errors.New("multiple matching DNS records") + + // Errors returned on a successful match. + ErrMatchedAll = errors.New("matched all") + ErrMatchedA = errors.New("matched a") + ErrMatchedIP = errors.New("matched ip") + ErrMatchedMX = errors.New("matched mx") + ErrMatchedPTR = errors.New("matched ptr") + ErrMatchedExists = errors.New("matched exists") ) +// Default value for the maximum number of DNS lookups while resolving SPF. +// RFC is quite clear 10 must be the maximum allowed. +// https://tools.ietf.org/html/rfc7208#section-4.6.4 +const defaultMaxLookups = 10 + +// TraceFunc is the type of tracing functions. +type TraceFunc func(f string, a ...interface{}) + +var ( + nullTrace = func(f string, a ...interface{}) {} + defaultTrace = nullTrace +) + +// Option type, for setting options. Users are expected to treat this as an +// opaque type and not rely on the implementation, which is subject to change. +type Option func(*resolution) + // CheckHost fetches SPF records for `domain`, parses them, and evaluates them // to determine if `ip` is permitted to send mail for it. // Because it doesn't receive enough information to handle macros well, its // usage is not recommended, but remains supported for backwards // compatibility. +// +// The function returns a Result, which corresponds with the SPF result for +// the check as per RFC, as well as an error for debugging purposes. Note that +// the error may be non-nil even on successful checks. +// // Reference: https://tools.ietf.org/html/rfc7208#section-4 +// +// Deprecated: use CheckHostWithSender instead. func CheckHost(ip net.IP, domain string) (Result, error) { - trace("check host %q %q", ip, domain) - r := &resolution{ip, 0, "@" + domain, nil} + r := &resolution{ + ip: ip, + maxcount: defaultMaxLookups, + helo: domain, + sender: "@" + domain, + ctx: context.TODO(), + resolver: defaultResolver, + trace: defaultTrace, + } return r.Check(domain) } // CheckHostWithSender fetches SPF records for `sender`'s domain, parses them, // and evaluates them to determine if `ip` is permitted to send mail for it. // The `helo` domain is used if the sender has no domain part. +// +// The `opts` optional parameter can be used to adjust some specific +// behaviours, such as the maximum number of DNS lookups allowed. +// +// The function returns a Result, which corresponds with the SPF result for +// the check as per RFC, as well as an error for debugging purposes. Note that +// the error may be non-nil even on successful checks. +// // Reference: https://tools.ietf.org/html/rfc7208#section-4 -func CheckHostWithSender(ip net.IP, helo, sender string) (Result, error) { +func CheckHostWithSender(ip net.IP, helo, sender string, opts ...Option) (Result, error) { _, domain := split(sender) if domain == "" { domain = helo } - trace("check host with sender %q %q %q (%q)", ip, helo, sender, domain) - r := &resolution{ip, 0, sender, nil} + r := &resolution{ + ip: ip, + maxcount: defaultMaxLookups, + helo: helo, + sender: sender, + ctx: context.TODO(), + resolver: defaultResolver, + trace: defaultTrace, + } + + for _, opt := range opts { + opt(r) + } + return r.Check(domain) } +// OverrideLookupLimit overrides the maximum number of DNS lookups allowed +// during SPF evaluation. Note that using this violates the RFC, which is +// quite explicit that the maximum allowed MUST be 10 (the default). Please +// use with care. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func OverrideLookupLimit(limit uint) Option { + return func(r *resolution) { + r.maxcount = limit + } +} + +// WithContext is an option to set the context for this operation, which will +// be passed along to the resolver functions and other external calls if +// needed. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithContext(ctx context.Context) Option { + return func(r *resolution) { + r.ctx = ctx + } +} + +// DNSResolver implements the methods we use to resolve DNS queries. +// It is intentionally compatible with *net.Resolver. +type DNSResolver interface { + LookupTXT(ctx context.Context, name string) ([]string, error) + LookupMX(ctx context.Context, name string) ([]*net.MX, error) + LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) + LookupAddr(ctx context.Context, addr string) (names []string, err error) +} + +var defaultResolver DNSResolver = net.DefaultResolver + +// WithResolver sets the resolver to use for DNS lookups. It can be useful for +// testing, and for customize DNS resolution specifically for this library. +// +// The default is to use net.DefaultResolver, which should be appropriate for +// most users. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithResolver(resolver DNSResolver) Option { + return func(r *resolution) { + r.resolver = resolver + } +} + +// WithTraceFunc sets the resolver's trace function. +// +// This can be used for debugging. The trace messages are NOT machine +// parseable, and are NOT stable. They should also NOT be included in +// user-visible output, as they may include sensitive details. +// +// This is EXPERIMENTAL for now, and the API is subject to change. +func WithTraceFunc(trace TraceFunc) Option { + return func(r *resolution) { + r.trace = trace + } +} + // split an user@domain address into user and domain. func split(addr string) (string, string) { ps := strings.SplitN(addr, "@", 2) @@ -142,43 +255,54 @@ } type resolution struct { - ip net.IP - count uint + ip net.IP + count uint + maxcount uint + helo string sender string // Result of doing a reverse lookup for ip (so we only do it once). ipNames []string + + // Context for this resolution. + ctx context.Context + + // DNS resolver to use. + resolver DNSResolver + + // Trace function, used for debugging. + trace TraceFunc } -var aField = regexp.MustCompile(`^a$|a:|a/`) -var mxField = regexp.MustCompile(`^mx$|mx:|mx/`) -var ptrField = regexp.MustCompile(`^ptr$|ptr:`) +var aField = regexp.MustCompile(`^(a$|a:|a/)`) +var mxField = regexp.MustCompile(`^(mx$|mx:|mx/)`) +var ptrField = regexp.MustCompile(`^(ptr$|ptr:)`) func (r *resolution) Check(domain string) (Result, error) { r.count++ - trace("check %s %d", domain, r.count) - txt, err := getDNSRecord(domain) + r.trace("check %q %d", domain, r.count) + txt, err := r.getDNSRecord(domain) if err != nil { if isTemporary(err) { - trace("dns temp error: %v", err) + r.trace("dns temp error: %v", err) return TempError, err } - if err == errMultipleRecords { - trace("multiple dns records") + if err == ErrMultipleRecords { + r.trace("multiple dns records") return PermError, err } // Could not resolve the name, it may be missing the record. // https://tools.ietf.org/html/rfc7208#section-2.6.1 - trace("dns perm error: %v", err) + r.trace("dns perm error: %v", err) return None, err } - trace("dns record %q", txt) + r.trace("dns record %q", txt) if txt == "" { // No record => None. // https://tools.ietf.org/html/rfc7208#section-4.5 - return None, errNoResult + return None, ErrNoResult } fields := strings.Split(txt, " ") @@ -196,7 +320,7 @@ if len(redirects) > 1 { // At most a single redirect is allowed. // https://tools.ietf.org/html/rfc7208#section-6 - return PermError, errInvalidDomain + return PermError, ErrInvalidDomain } fields = append(newfields, redirects...) @@ -212,11 +336,11 @@ continue } - // Limit the number of resolutions to 10 + // Limit the number of resolutions. // https://tools.ietf.org/html/rfc7208#section-4.6.4 - if r.count > 10 { - trace("lookup limit reached") - return PermError, errLookupLimitReached + if r.count > r.maxcount { + r.trace("lookup limit reached") + return PermError, ErrLookupLimitReached } // See if we have a qualifier, defaulting to + (pass). @@ -234,54 +358,54 @@ if lfield == "all" { // https://tools.ietf.org/html/rfc7208#section-5.1 - trace("%v matched all", result) - return result, errMatchedAll + r.trace("%v matched all", result) + return result, ErrMatchedAll } else if strings.HasPrefix(lfield, "include:") { if ok, res, err := r.includeField(result, field, domain); ok { - trace("include ok, %v %v", res, err) + r.trace("include ok, %v %v", res, err) return res, err } } else if aField.MatchString(lfield) { if ok, res, err := r.aField(result, field, domain); ok { - trace("a ok, %v %v", res, err) + r.trace("a ok, %v %v", res, err) return res, err } } else if mxField.MatchString(lfield) { if ok, res, err := r.mxField(result, field, domain); ok { - trace("mx ok, %v %v", res, err) + r.trace("mx ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "ip4:") || strings.HasPrefix(lfield, "ip6:") { if ok, res, err := r.ipField(result, field); ok { - trace("ip ok, %v %v", res, err) + r.trace("ip ok, %v %v", res, err) return res, err } } else if ptrField.MatchString(lfield) { if ok, res, err := r.ptrField(result, field, domain); ok { - trace("ptr ok, %v %v", res, err) + r.trace("ptr ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "exists:") { if ok, res, err := r.existsField(result, field, domain); ok { - trace("exists ok, %v %v", res, err) + r.trace("exists ok, %v %v", res, err) return res, err } } else if strings.HasPrefix(lfield, "exp=") { - trace("exp= not used, skipping") + r.trace("exp= not used, skipping") continue } else if strings.HasPrefix(lfield, "redirect=") { - trace("redirect, %q", field) + r.trace("redirect, %q", field) return r.redirectField(field, domain) } else { // http://www.openspf.org/SPF_Record_Syntax - trace("permerror, unknown field") - return PermError, errUnknownField + r.trace("permerror, unknown field") + return PermError, ErrUnknownField } } // Got to the end of the evaluation without a result => Neutral. // https://tools.ietf.org/html/rfc7208#section-4.7 - trace("fallback to neutral") + r.trace("fallback to neutral") return Neutral, nil } @@ -290,8 +414,8 @@ // https://tools.ietf.org/html/rfc7208#section-3 // https://tools.ietf.org/html/rfc7208#section-3.2 // https://tools.ietf.org/html/rfc7208#section-4.5 -func getDNSRecord(domain string) (string, error) { - txts, err := lookupTXT(domain) +func (r *resolution) getDNSRecord(domain string) (string, error) { + txts, err := r.resolver.LookupTXT(r.ctx, domain) if err != nil { return "", err } @@ -322,7 +446,7 @@ } else if l == 1 { return records[0], nil } - return "", errMultipleRecords + return "", ErrMultipleRecords } func isTemporary(err error) bool { @@ -336,18 +460,18 @@ if strings.Contains(fip, "/") { _, ipnet, err := net.ParseCIDR(fip) if err != nil { - return true, PermError, errInvalidMask + return true, PermError, ErrInvalidMask } if ipnet.Contains(r.ip) { - return true, res, errMatchedIP + return true, res, ErrMatchedIP } } else { ip := net.ParseIP(fip) if ip == nil { - return true, PermError, errInvalidIP + return true, PermError, ErrInvalidIP } if ip.Equal(r.ip) { - return true, res, errMatchedIP + return true, res, ErrMatchedIP } } @@ -364,17 +488,17 @@ } ptrDomain, err := r.expandMacros(ptrDomain, domain) if err != nil { - return true, PermError, errInvalidMacro + return true, PermError, ErrInvalidMacro } if ptrDomain == "" { - return true, PermError, errInvalidDomain + return true, PermError, ErrInvalidDomain } if r.ipNames == nil { r.ipNames = []string{} r.count++ - ns, err := lookupAddr(r.ip.String()) + ns, err := r.resolver.LookupAddr(r.ctx, r.ip.String()) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -387,15 +511,15 @@ // have some A/AAAA. // https://tools.ietf.org/html/rfc7208#section-5.5 if r.count > 10 { - return false, "", errLookupLimitReached + return false, "", ErrLookupLimitReached } r.count++ - addrs, err := lookupIP(n) + addrs, err := r.resolver.LookupIPAddr(r.ctx, n) if err != nil { // RFC explicitly says to skip domains which error here. continue } - trace("ptr forward resolution %q -> %q", n, addrs) + r.trace("ptr forward resolution %q -> %q", n, addrs) if len(addrs) > 0 { // Append the lower-case variants so we do a case-insensitive // lookup below. @@ -404,11 +528,11 @@ } } - trace("ptr evaluating %q in %q", ptrDomain, r.ipNames) + r.trace("ptr evaluating %q in %q", ptrDomain, r.ipNames) ptrDomain = strings.ToLower(ptrDomain) for _, n := range r.ipNames { if strings.HasSuffix(n, ptrDomain+".") { - return true, res, errMatchedPTR + return true, res, ErrMatchedPTR } } @@ -422,15 +546,15 @@ eDomain := field[7:] eDomain, err := r.expandMacros(eDomain, domain) if err != nil { - return true, PermError, errInvalidMacro + return true, PermError, ErrInvalidMacro } if eDomain == "" { - return true, PermError, errInvalidDomain + return true, PermError, ErrInvalidDomain } r.count++ - ips, err := lookupIP(eDomain) + ips, err := r.resolver.LookupIPAddr(r.ctx, eDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -441,8 +565,8 @@ // Exists only counts if there are IPv4 matches. for _, ip := range ips { - if ip.To4() != nil { - return true, res, errMatchedExists + if ip.IP.To4() != nil { + return true, res, ErrMatchedExists } } return false, "", nil @@ -454,7 +578,7 @@ incdomain := field[len("include:"):] incdomain, err := r.expandMacros(incdomain, domain) if err != nil { - return true, PermError, errInvalidMacro + return true, PermError, ErrInvalidMacro } ir, err := r.Check(incdomain) switch ir { @@ -470,65 +594,62 @@ return true, PermError, err } - return false, "", fmt.Errorf("This should never be reached") + return false, "", fmt.Errorf("this should never be reached") } type dualMasks struct { - v4 int - v6 int + v4 net.IPMask + v6 net.IPMask } -func ipMatch(ip, tomatch net.IP, masks dualMasks) (bool, error) { - mask := -1 - if tomatch.To4() != nil && masks.v4 >= 0 { +func ipMatch(ip, tomatch net.IP, masks dualMasks) bool { + mask := net.IPMask(nil) + if tomatch.To4() != nil && masks.v4 != nil { mask = masks.v4 - } else if tomatch.To4() == nil && masks.v6 >= 0 { + } else if tomatch.To4() == nil && masks.v6 != nil { mask = masks.v6 } - if mask >= 0 { - _, ipnet, err := net.ParseCIDR( - fmt.Sprintf("%s/%d", tomatch.String(), mask)) - if err != nil { - return false, errInvalidMask - } - return ipnet.Contains(ip), nil + if mask != nil { + ipnet := net.IPNet{IP: tomatch, Mask: mask} + return ipnet.Contains(ip) } - return ip.Equal(tomatch), nil + return ip.Equal(tomatch) } var aRegexp = regexp.MustCompile(`^[aA](:([^/]+))?(/(\w+))?(//(\w+))?$`) var mxRegexp = regexp.MustCompile(`^[mM][xX](:([^/]+))?(/(\w+))?(//(\w+))?$`) func domainAndMask(re *regexp.Regexp, field, domain string) (string, dualMasks, error) { - masks := dualMasks{-1, -1} + masks := dualMasks{} groups := re.FindStringSubmatch(field) if groups != nil { if groups[2] != "" { domain = groups[2] } if groups[4] != "" { - mask4, err := strconv.Atoi(groups[4]) - if err != nil || mask4 < 0 || mask4 > 32 { - return "", masks, errInvalidMask + i, err := strconv.Atoi(groups[4]) + mask4 := net.CIDRMask(i, 32) + if err != nil || mask4 == nil { + return "", masks, ErrInvalidMask } masks.v4 = mask4 } if groups[6] != "" { - mask6, err := strconv.Atoi(groups[6]) - if err != nil || mask6 < 0 || mask6 > 128 { - return "", masks, errInvalidMask + i, err := strconv.Atoi(groups[6]) + mask6 := net.CIDRMask(i, 128) + if err != nil || mask6 == nil { + return "", masks, ErrInvalidMask } masks.v6 = mask6 } } - trace("masks on %q: %q %q %v", field, groups, domain, masks) // Test to catch malformed entries: if there's a /, there must be at least // one mask. - if strings.Contains(field, "/") && masks.v4 == -1 && masks.v6 == -1 { - return "", masks, errInvalidMask + if strings.Contains(field, "/") && masks.v4 == nil && masks.v6 == nil { + return "", masks, ErrInvalidMask } return domain, masks, nil @@ -538,16 +659,17 @@ func (r *resolution) aField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.3 aDomain, masks, err := domainAndMask(aRegexp, field, domain) + r.trace("masks on %q, %q: %q %v", field, domain, aDomain, masks) if err != nil { return true, PermError, err } aDomain, err = r.expandMacros(aDomain, domain) if err != nil { - return true, PermError, errInvalidMacro + return true, PermError, ErrInvalidMacro } r.count++ - ips, err := lookupIP(aDomain) + ips, err := r.resolver.LookupIPAddr(r.ctx, aDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -556,12 +678,9 @@ return false, "", err } for _, ip := range ips { - ok, err := ipMatch(r.ip, ip, masks) - if ok { - trace("mx matched %v, %v, %v", r.ip, ip, masks) - return true, res, errMatchedA - } else if err != nil { - return true, PermError, err + if ipMatch(r.ip, ip.IP, masks) { + r.trace("a matched %v, %v, %v", r.ip, ip.IP, masks) + return true, res, ErrMatchedA } } @@ -572,16 +691,17 @@ func (r *resolution) mxField(res Result, field, domain string) (bool, Result, error) { // https://tools.ietf.org/html/rfc7208#section-5.4 mxDomain, masks, err := domainAndMask(mxRegexp, field, domain) + r.trace("masks on %q, %q: %q %v", field, domain, mxDomain, masks) if err != nil { return true, PermError, err } mxDomain, err = r.expandMacros(mxDomain, domain) if err != nil { - return true, PermError, errInvalidMacro + return true, PermError, ErrInvalidMacro } r.count++ - mxs, err := lookupMX(mxDomain) + mxs, err := r.resolver.LookupMX(r.ctx, mxDomain) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -593,13 +713,13 @@ // There's an explicit maximum of 10 MX records per match. // https://tools.ietf.org/html/rfc7208#section-4.6.4 if len(mxs) > 10 { - return true, PermError, errTooManyMXRecords + return true, PermError, ErrTooManyMXRecords } mxips := []net.IP{} for _, mx := range mxs { r.count++ - ips, err := lookupIP(mx.Host) + ips, err := r.resolver.LookupIPAddr(r.ctx, mx.Host) if err != nil { // https://tools.ietf.org/html/rfc7208#section-5 if isTemporary(err) { @@ -607,15 +727,14 @@ } return false, "", err } - mxips = append(mxips, ips...) + for _, ipaddr := range ips { + mxips = append(mxips, ipaddr.IP) + } } for _, ip := range mxips { - ok, err := ipMatch(r.ip, ip, masks) - if ok { - trace("mx matched %v, %v, %v", r.ip, ip, masks) - return true, res, errMatchedMX - } else if err != nil { - return true, PermError, err + if ipMatch(r.ip, ip, masks) { + r.trace("mx matched %v, %v, %v", r.ip, ip, masks) + return true, res, ErrMatchedMX } } @@ -627,11 +746,11 @@ rDomain := field[len("redirect="):] rDomain, err := r.expandMacros(rDomain, domain) if err != nil { - return PermError, errInvalidMacro + return PermError, ErrInvalidMacro } if rDomain == "" { - return PermError, errInvalidDomain + return PermError, ErrInvalidDomain } // https://tools.ietf.org/html/rfc7208#section-6.1 @@ -656,8 +775,8 @@ // from happening in case where it matters (a, mx), but for the ones which // doesn't, prevent them from sneaking through. if strings.Contains(s, "/") { - trace("macro contains /") - return "", errInvalidDomain + r.trace("macro contains /") + return "", ErrInvalidDomain } // Bypass the complex logic if there are no macros present. @@ -693,7 +812,7 @@ inMacroDefinition = true continue } - return "", errInvalidMacro + return "", ErrInvalidMacro } if inMacroDefinition { if c != '}' { @@ -705,10 +824,10 @@ // Extract letter, digit transformer, reverse transformer, and // delimiters. groups := macroRegexp.FindStringSubmatch(macroS) - trace("macro %q: %q", macroS, groups) + r.trace("macro %q: %q", macroS, groups) macroS = "" if groups == nil { - return "", errInvalidMacro + return "", ErrInvalidMacro } letter := groups[1] @@ -718,7 +837,7 @@ // valid. digits, err = strconv.Atoi(groups[2]) if err != nil || digits <= 0 { - return "", errInvalidMacro + return "", ErrInvalidMacro } } reverse := groups[3] == "r" || groups[3] == "R" @@ -743,7 +862,7 @@ case "d": str = domain case "i": - str = r.ip.String() + str = ipToMacroStr(r.ip) case "p": // This shouldn't be used, we don't want to support it, it's // risky. "unknown" is a safe value. @@ -756,11 +875,11 @@ str = "ip6" } case "h": - str = domain + str = r.helo default: // c, r, t are allowed in exp only, and we don't expand macros // in exp so they are just as invalid as the rest. - return "", errInvalidMacro + return "", ErrInvalidMacro } // Split str using the given separators. @@ -802,7 +921,7 @@ n += string(c) } - trace("macro expanded %q to %q", s, n) + r.trace("macro expanded %q to %q", s, n) return n, nil } @@ -811,3 +930,19 @@ a[left], a[right] = a[right], a[left] } } + +func ipToMacroStr(ip net.IP) string { + if ip.To4() != nil { + return ip.String() + } + + // For IPv6 addresses, the "i" macro expands to a dot-format address. + // https://datatracker.ietf.org/doc/html/rfc7208#section-7.3 + sb := strings.Builder{} + sb.Grow(64) + for _, b := range ip.To16() { + fmt.Fprintf(&sb, "%x.%x.", b>>4, b&0xf) + } + // Return the string without the trailing ".". + return sb.String()[:sb.Len()-1] +} diff -Nru golang-blitiri-go-spf-1.1.0/spf_test.go golang-blitiri-go-spf-1.3.0/spf_test.go --- golang-blitiri-go-spf-1.1.0/spf_test.go 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/spf_test.go 2021-11-20 17:24:26.000000000 +0000 @@ -1,78 +1,94 @@ package spf import ( + "context" "fmt" "net" "testing" + + "blitiri.com.ar/go/spf/internal/dnstest" ) +func NewDefaultResolver() *dnstest.TestResolver { + dns := dnstest.NewResolver() + defaultResolver = dns + return dns +} + +func init() { + // Override the default resolver to make sure the tests are not using the + // one from net. Individual tests will override this as well, but just in + // case. + NewDefaultResolver() +} + var ip1110 = net.ParseIP("1.1.1.0") var ip1111 = net.ParseIP("1.1.1.1") var ip6666 = net.ParseIP("2001:db8::68") var ip6660 = net.ParseIP("2001:db8::0") func TestBasic(t *testing.T) { - dns = NewDNS() - trace = t.Logf + dns := NewDefaultResolver() + defaultTrace = t.Logf cases := []struct { txt string res Result err error }{ - {"", None, errNoResult}, - {"blah", None, errNoResult}, + {"", None, ErrNoResult}, + {"blah", None, ErrNoResult}, {"v=spf1", Neutral, nil}, {"v=spf1 ", Neutral, nil}, - {"v=spf1 -", PermError, errUnknownField}, - {"v=spf1 all", Pass, errMatchedAll}, - {"v=spf1 exp=blah +all", Pass, errMatchedAll}, - {"v=spf1 +all", Pass, errMatchedAll}, - {"v=spf1 -all ", Fail, errMatchedAll}, - {"v=spf1 ~all", SoftFail, errMatchedAll}, - {"v=spf1 ?all", Neutral, errMatchedAll}, - {"v=spf1 a ~all", SoftFail, errMatchedAll}, + {"v=spf1 -", PermError, ErrUnknownField}, + {"v=spf1 all", Pass, ErrMatchedAll}, + {"v=spf1 exp=blah +all", Pass, ErrMatchedAll}, + {"v=spf1 +all", Pass, ErrMatchedAll}, + {"v=spf1 -all ", Fail, ErrMatchedAll}, + {"v=spf1 ~all", SoftFail, ErrMatchedAll}, + {"v=spf1 ?all", Neutral, ErrMatchedAll}, + {"v=spf1 a ~all", SoftFail, ErrMatchedAll}, {"v=spf1 a/24", Neutral, nil}, - {"v=spf1 a:d1110/24", Pass, errMatchedA}, - {"v=spf1 a:d1110/montoto", PermError, errInvalidMask}, - {"v=spf1 a:d1110/99", PermError, errInvalidMask}, + {"v=spf1 a:d1110/24", Pass, ErrMatchedA}, + {"v=spf1 a:d1110/montoto", PermError, ErrInvalidMask}, + {"v=spf1 a:d1110/99", PermError, ErrInvalidMask}, {"v=spf1 a:d1110/32", Neutral, nil}, {"v=spf1 a:d1110", Neutral, nil}, - {"v=spf1 a:d1111", Pass, errMatchedA}, + {"v=spf1 a:d1111", Pass, ErrMatchedA}, {"v=spf1 a:nothing/24", Neutral, nil}, {"v=spf1 mx", Neutral, nil}, {"v=spf1 mx/24", Neutral, nil}, - {"v=spf1 mx:a/montoto ~all", PermError, errInvalidMask}, - {"v=spf1 mx:d1110/24 ~all", Pass, errMatchedMX}, - {"v=spf1 mx:d1110/24//100 ~all", Pass, errMatchedMX}, - {"v=spf1 mx:d1110/24//129 ~all", PermError, errInvalidMask}, - {"v=spf1 mx:d1110/24/100 ~all", PermError, errInvalidMask}, - {"v=spf1 mx:d1110/99 ~all", PermError, errInvalidMask}, - {"v=spf1 ip4:1.2.3.4 ~all", SoftFail, errMatchedAll}, - {"v=spf1 ip6:12 ~all", PermError, errInvalidIP}, - {"v=spf1 ip4:1.1.1.1 -all", Pass, errMatchedIP}, - {"v=spf1 ip4:1.1.1.1/24 -all", Pass, errMatchedIP}, - {"v=spf1 ip4:1.1.1.1/lala -all", PermError, errInvalidMask}, - {"v=spf1 ip4:1.1.1.1/33 -all", PermError, errInvalidMask}, - {"v=spf1 include:doesnotexist", PermError, errNoResult}, - {"v=spf1 ptr -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:d1111 -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:lalala -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:doesnotexist -all", Fail, errMatchedAll}, - {"v=spf1 blah", PermError, errUnknownField}, - {"v=spf1 exists:d1111 -all", Pass, errMatchedExists}, - {"v=spf1 redirect=", PermError, errInvalidDomain}, - } - - dns.ip["d1111"] = []net.IP{ip1111} - dns.ip["d1110"] = []net.IP{ip1110} - dns.mx["d1110"] = []*net.MX{mx("d1110", 5), mx("nothing", 10)} - dns.addr["1.1.1.1"] = []string{"lalala.", "xx.domain.", "d1111."} - dns.ip["lalala"] = []net.IP{ip1111} - dns.ip["xx.domain"] = []net.IP{ip1111} + {"v=spf1 mx:a/montoto ~all", PermError, ErrInvalidMask}, + {"v=spf1 mx:d1110/24 ~all", Pass, ErrMatchedMX}, + {"v=spf1 mx:d1110/24//100 ~all", Pass, ErrMatchedMX}, + {"v=spf1 mx:d1110/24//129 ~all", PermError, ErrInvalidMask}, + {"v=spf1 mx:d1110/24/100 ~all", PermError, ErrInvalidMask}, + {"v=spf1 mx:d1110/99 ~all", PermError, ErrInvalidMask}, + {"v=spf1 ip4:1.2.3.4 ~all", SoftFail, ErrMatchedAll}, + {"v=spf1 ip6:12 ~all", PermError, ErrInvalidIP}, + {"v=spf1 ip4:1.1.1.1 -all", Pass, ErrMatchedIP}, + {"v=spf1 ip4:1.1.1.1/24 -all", Pass, ErrMatchedIP}, + {"v=spf1 ip4:1.1.1.1/lala -all", PermError, ErrInvalidMask}, + {"v=spf1 ip4:1.1.1.1/33 -all", PermError, ErrInvalidMask}, + {"v=spf1 include:doesnotexist", PermError, ErrNoResult}, + {"v=spf1 ptr -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:d1111 -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:lalala -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:doesnotexist -all", Fail, ErrMatchedAll}, + {"v=spf1 blah", PermError, ErrUnknownField}, + {"v=spf1 exists:d1111 -all", Pass, ErrMatchedExists}, + {"v=spf1 redirect=", PermError, ErrInvalidDomain}, + } + + dns.Ip["d1111"] = []net.IP{ip1111} + dns.Ip["d1110"] = []net.IP{ip1110} + dns.Mx["d1110"] = []*net.MX{mx("d1110", 5), mx("nothing", 10)} + dns.Addr["1.1.1.1"] = []string{"lalala.", "xx.domain.", "d1111."} + dns.Ip["lalala"] = []net.IP{ip1111} + dns.Ip["xx.domain"] = []net.IP{ip1111} for _, c := range cases { - dns.txt["domain"] = []string{c.txt} + dns.Txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -87,43 +103,43 @@ } func TestIPv6(t *testing.T) { - dns = NewDNS() - trace = t.Logf + dns := NewDefaultResolver() + defaultTrace = t.Logf cases := []struct { txt string res Result err error }{ - {"v=spf1 all", Pass, errMatchedAll}, - {"v=spf1 a ~all", SoftFail, errMatchedAll}, + {"v=spf1 all", Pass, ErrMatchedAll}, + {"v=spf1 a ~all", SoftFail, ErrMatchedAll}, {"v=spf1 a/24", Neutral, nil}, - {"v=spf1 a:d6660//24", Pass, errMatchedA}, - {"v=spf1 a:d6660/24//100", Pass, errMatchedA}, + {"v=spf1 a:d6660//24", Pass, ErrMatchedA}, + {"v=spf1 a:d6660/24//100", Pass, ErrMatchedA}, {"v=spf1 a:d6660", Neutral, nil}, - {"v=spf1 a:d6666", Pass, errMatchedA}, + {"v=spf1 a:d6666", Pass, ErrMatchedA}, {"v=spf1 a:nothing//24", Neutral, nil}, - {"v=spf1 mx:d6660//24 ~all", Pass, errMatchedMX}, - {"v=spf1 mx:d6660/24//100 ~all", Pass, errMatchedMX}, - {"v=spf1 mx:d6660/24/100 ~all", PermError, errInvalidMask}, - {"v=spf1 ip6:2001:db8::68 ~all", Pass, errMatchedIP}, - {"v=spf1 ip6:2001:db8::1/24 ~all", Pass, errMatchedIP}, - {"v=spf1 ip6:2001:db8::1/100 ~all", Pass, errMatchedIP}, - {"v=spf1 ptr -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:d6666 -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:sonlas6 -all", Pass, errMatchedPTR}, - {"v=spf1 ptr:sonlas7 -all", Fail, errMatchedAll}, - } - - dns.ip["d6666"] = []net.IP{ip6666} - dns.ip["d6660"] = []net.IP{ip6660} - dns.mx["d6660"] = []*net.MX{mx("d6660", 5), mx("nothing", 10)} - dns.addr["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} - dns.ip["domain"] = []net.IP{ip1111} - dns.ip["sonlas6"] = []net.IP{ip6666} + {"v=spf1 mx:d6660//24 ~all", Pass, ErrMatchedMX}, + {"v=spf1 mx:d6660/24//100 ~all", Pass, ErrMatchedMX}, + {"v=spf1 mx:d6660/24/100 ~all", PermError, ErrInvalidMask}, + {"v=spf1 ip6:2001:db8::68 ~all", Pass, ErrMatchedIP}, + {"v=spf1 ip6:2001:db8::1/24 ~all", Pass, ErrMatchedIP}, + {"v=spf1 ip6:2001:db8::1/100 ~all", Pass, ErrMatchedIP}, + {"v=spf1 ptr -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:d6666 -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:sonlas6 -all", Pass, ErrMatchedPTR}, + {"v=spf1 ptr:sonlas7 -all", Fail, ErrMatchedAll}, + } + + dns.Ip["d6666"] = []net.IP{ip6666} + dns.Ip["d6660"] = []net.IP{ip6660} + dns.Mx["d6660"] = []*net.MX{mx("d6660", 5), mx("nothing", 10)} + dns.Addr["2001:db8::68"] = []string{"sonlas6.", "domain.", "d6666."} + dns.Ip["domain"] = []net.IP{ip1111} + dns.Ip["sonlas6"] = []net.IP{ip6666} for _, c := range cases { - dns.txt["domain"] = []string{c.txt} + dns.Txt["domain"] = []string{c.txt} res, err := CheckHost(ip6666, "domain") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -140,25 +156,25 @@ func TestInclude(t *testing.T) { // Test that the include is doing a recursive lookup. // If we got a match on 1.1.1.1, is because include:domain2 did not match. - dns = NewDNS() - dns.txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["domain"] = []string{"v=spf1 include:domain2 ip4:1.1.1.1"} + defaultTrace = t.Logf cases := []struct { txt string res Result err error }{ - {"", PermError, errNoResult}, - {"v=spf1 all", Pass, errMatchedAll}, + {"", PermError, ErrNoResult}, + {"v=spf1 all", Pass, ErrMatchedAll}, // domain2 did not pass, so continued and matched parent's ip4. - {"v=spf1", Pass, errMatchedIP}, - {"v=spf1 -all", Pass, errMatchedIP}, + {"v=spf1", Pass, ErrMatchedIP}, + {"v=spf1 -all", Pass, ErrMatchedIP}, } for _, c := range cases { - dns.txt["domain2"] = []string{c.txt} + dns.Txt["domain2"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res || err != c.err { t.Errorf("%q: expected [%v/%v], got [%v/%v]", @@ -168,21 +184,21 @@ } func TestRecursionLimit(t *testing.T) { - dns = NewDNS() - dns.txt["domain"] = []string{"v=spf1 include:domain ~all"} - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["domain"] = []string{"v=spf1 include:domain ~all"} + defaultTrace = t.Logf res, err := CheckHost(ip1111, "domain") - if res != PermError || err != errLookupLimitReached { + if res != PermError || err != ErrLookupLimitReached { t.Errorf("expected permerror, got %v (%v)", res, err) } } func TestRedirect(t *testing.T) { - dns = NewDNS() - dns.txt["domain"] = []string{"v=spf1 redirect=domain2"} - dns.txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["domain"] = []string{"v=spf1 redirect=domain2"} + dns.Txt["domain2"] = []string{"v=spf1 ip4:1.1.1.1 -all"} + defaultTrace = t.Logf res, err := CheckHost(ip1111, "domain") if res != Pass { @@ -194,9 +210,9 @@ // Redirect to a non-existing host; the inner check returns None, but due // to the redirection, this lookup should return PermError. // https://tools.ietf.org/html/rfc7208#section-6.1 - dns = NewDNS() - dns.txt["domain"] = []string{"v=spf1 redirect=doesnotexist"} - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["domain"] = []string{"v=spf1 redirect=doesnotexist"} + defaultTrace = t.Logf res, err := CheckHost(ip1111, "doesnotexist") if res != None { @@ -204,7 +220,7 @@ } res, err = CheckHost(ip1111, "domain") - if res != PermError || err != errNoResult { + if res != PermError || err != ErrNoResult { t.Errorf("expected permerror, got %v (%v)", res, err) } } @@ -212,29 +228,29 @@ func TestRedirectOrder(t *testing.T) { // We should only check redirects after all mechanisms, even if the // redirect modifier appears before them. - dns = NewDNS() - dns.txt["faildom"] = []string{"v=spf1 -all"} - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["faildom"] = []string{"v=spf1 -all"} + defaultTrace = t.Logf - dns.txt["domain"] = []string{"v=spf1 redirect=faildom"} + dns.Txt["domain"] = []string{"v=spf1 redirect=faildom"} res, err := CheckHost(ip1111, "domain") - if res != Fail || err != errMatchedAll { + if res != Fail || err != ErrMatchedAll { t.Errorf("expected fail, got %v (%v)", res, err) } - dns.txt["domain"] = []string{"v=spf1 redirect=faildom all"} + dns.Txt["domain"] = []string{"v=spf1 redirect=faildom all"} res, err = CheckHost(ip1111, "domain") - if res != Pass || err != errMatchedAll { + if res != Pass || err != ErrMatchedAll { t.Errorf("expected pass, got %v (%v)", res, err) } } func TestNoRecord(t *testing.T) { - dns = NewDNS() - dns.txt["d1"] = []string{""} - dns.txt["d2"] = []string{"loco", "v=spf2"} - dns.errors["nospf"] = fmt.Errorf("no such domain") - trace = t.Logf + dns := NewDefaultResolver() + dns.Txt["d1"] = []string{""} + dns.Txt["d2"] = []string{"loco", "v=spf2"} + dns.Errors["nospf"] = fmt.Errorf("no such domain") + defaultTrace = t.Logf for _, domain := range []string{"d1", "d2", "d3", "nospf"} { res, err := CheckHost(ip1111, domain) @@ -245,17 +261,17 @@ } func TestDNSTemporaryErrors(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dnsError := &net.DNSError{ Err: "temporary error for testing", IsTemporary: true, } // Domain "tmperr" will fail resolution with a temporary error. - dns.errors["tmperr"] = dnsError - dns.errors["1.1.1.1"] = dnsError - dns.mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} - trace = t.Logf + dns.Errors["tmperr"] = dnsError + dns.Errors["1.1.1.1"] = dnsError + dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} + defaultTrace = t.Logf cases := []struct { txt string @@ -269,7 +285,7 @@ } for _, c := range cases { - dns.txt["domain"] = []string{c.txt} + dns.Txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res { t.Errorf("%q: expected %v, got %v (%v)", @@ -279,17 +295,17 @@ } func TestDNSPermanentErrors(t *testing.T) { - dns = NewDNS() + dns := NewDefaultResolver() dnsError := &net.DNSError{ Err: "permanent error for testing", IsTemporary: false, } // Domain "tmperr" will fail resolution with a temporary error. - dns.errors["tmperr"] = dnsError - dns.errors["1.1.1.1"] = dnsError - dns.mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} - trace = t.Logf + dns.Errors["tmperr"] = dnsError + dns.Errors["1.1.1.1"] = dnsError + dns.Mx["tmpmx"] = []*net.MX{mx("tmperr", 10)} + defaultTrace = t.Logf cases := []struct { txt string @@ -303,7 +319,7 @@ } for _, c := range cases { - dns.txt["domain"] = []string{c.txt} + dns.Txt["domain"] = []string{c.txt} res, err := CheckHost(ip1111, "domain") if res != c.res { t.Errorf("%q: expected %v, got %v (%v)", @@ -313,8 +329,8 @@ } func TestMacros(t *testing.T) { - dns = NewDNS() - trace = t.Logf + dns := NewDefaultResolver() + defaultTrace = t.Logf // Most of the cases are covered by the standard test suite, so this is // targeted at gaps in coverage. @@ -323,26 +339,30 @@ res Result err error }{ - {"v=spf1 ptr:%{fff} -all", PermError, errInvalidMacro}, - {"v=spf1 mx:%{fff} -all", PermError, errInvalidMacro}, - {"v=spf1 redirect=%{fff}", PermError, errInvalidMacro}, - {"v=spf1 a:%{o0}", PermError, errInvalidMacro}, - {"v=spf1 +a:sss-%{s}-sss", Pass, errMatchedA}, - {"v=spf1 +a:ooo-%{o}-ooo", Pass, errMatchedA}, - {"v=spf1 +a:OOO-%{O}-OOO", Pass, errMatchedA}, - {"v=spf1 +a:ppp-%{p}-ppp", Pass, errMatchedA}, - {"v=spf1 +a:vvv-%{v}-vvv", Pass, errMatchedA}, - {"v=spf1 a:%{x}", PermError, errInvalidMacro}, - {"v=spf1 +a:ooo-%{o7}-ooo", Pass, errMatchedA}, - } - - dns.ip["sss-user@domain-sss"] = []net.IP{ip6666} - dns.ip["ooo-domain-ooo"] = []net.IP{ip6666} - dns.ip["ppp-unknown-ppp"] = []net.IP{ip6666} - dns.ip["vvv-ip6-vvv"] = []net.IP{ip6666} + {"v=spf1 ptr:%{fff} -all", PermError, ErrInvalidMacro}, + {"v=spf1 mx:%{fff} -all", PermError, ErrInvalidMacro}, + {"v=spf1 redirect=%{fff}", PermError, ErrInvalidMacro}, + {"v=spf1 a:%{o0}", PermError, ErrInvalidMacro}, + {"v=spf1 +a:sss-%{s}-sss", Pass, ErrMatchedA}, + {"v=spf1 +a:ooo-%{o}-ooo", Pass, ErrMatchedA}, + {"v=spf1 +a:OOO-%{O}-OOO", Pass, ErrMatchedA}, + {"v=spf1 +a:ppp-%{p}-ppp", Pass, ErrMatchedA}, + {"v=spf1 +a:hhh-%{h}-hhh", Pass, ErrMatchedA}, + {"v=spf1 +a:vvv-%{v}-vvv", Pass, ErrMatchedA}, + {"v=spf1 a:%{x}", PermError, ErrInvalidMacro}, + {"v=spf1 +a:ooo-%{o7}-ooo", Pass, ErrMatchedA}, + {"v=spf1 exists:%{ir}.vvv -all", Pass, ErrMatchedExists}, + } + + dns.Ip["sss-user@domain-sss"] = []net.IP{ip6666} + dns.Ip["ooo-domain-ooo"] = []net.IP{ip6666} + dns.Ip["ppp-unknown-ppp"] = []net.IP{ip6666} + dns.Ip["vvv-ip6-vvv"] = []net.IP{ip6666} + dns.Ip["hhh-helo-hhh"] = []net.IP{ip6666} + dns.Ip["8.6.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.vvv"] = []net.IP{ip1111} for _, c := range cases { - dns.txt["domain"] = []string{c.txt} + dns.Txt["domain"] = []string{c.txt} res, err := CheckHostWithSender(ip6666, "helo", "user@domain") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -357,8 +377,8 @@ } func TestMacrosV4(t *testing.T) { - dns = NewDNS() - trace = t.Logf + dns := NewDefaultResolver() + defaultTrace = t.Logf // Like TestMacros above, but specifically for IPv4. // It's easier to have a separate suite. @@ -368,23 +388,23 @@ res Result err error }{ - {"v=spf1 +a:sr-%{sr}-sr", Pass, errMatchedA}, - {"v=spf1 +a:sra-%{sr.}-sra", Pass, errMatchedA}, - {"v=spf1 +a:o7-%{o7}-o7", Pass, errMatchedA}, - {"v=spf1 +a:o1-%{o1}-o1", Pass, errMatchedA}, - {"v=spf1 +a:o1r-%{o1r}-o1r", Pass, errMatchedA}, - {"v=spf1 +a:vvv-%{v}-vvv", Pass, errMatchedA}, + {"v=spf1 +a:sr-%{sr}-sr", Pass, ErrMatchedA}, + {"v=spf1 +a:sra-%{sr.}-sra", Pass, ErrMatchedA}, + {"v=spf1 +a:o7-%{o7}-o7", Pass, ErrMatchedA}, + {"v=spf1 +a:o1-%{o1}-o1", Pass, ErrMatchedA}, + {"v=spf1 +a:o1r-%{o1r}-o1r", Pass, ErrMatchedA}, + {"v=spf1 +a:vvv-%{v}-vvv", Pass, ErrMatchedA}, } - dns.ip["sr-com.user@domain-sr"] = []net.IP{ip1111} - dns.ip["sra-com.user@domain-sra"] = []net.IP{ip1111} - dns.ip["o7-domain.com-o7"] = []net.IP{ip1111} - dns.ip["o1-com-o1"] = []net.IP{ip1111} - dns.ip["o1r-domain-o1r"] = []net.IP{ip1111} - dns.ip["vvv-in-addr-vvv"] = []net.IP{ip1111} + dns.Ip["sr-com.user@domain-sr"] = []net.IP{ip1111} + dns.Ip["sra-com.user@domain-sra"] = []net.IP{ip1111} + dns.Ip["o7-domain.com-o7"] = []net.IP{ip1111} + dns.Ip["o1-com-o1"] = []net.IP{ip1111} + dns.Ip["o1r-domain-o1r"] = []net.IP{ip1111} + dns.Ip["vvv-in-addr-vvv"] = []net.IP{ip1111} for _, c := range cases { - dns.txt["domain.com"] = []string{c.txt} + dns.Txt["domain.com"] = []string{c.txt} res, err := CheckHostWithSender(ip1111, "helo", "user@domain.com") if (res == TempError || res == PermError) && (err == nil) { t.Errorf("%q: expected error, got nil", c.txt) @@ -401,3 +421,213 @@ func mx(host string, pref uint16) *net.MX { return &net.MX{Host: host, Pref: pref} } + +func mkDM(v4, v6 int) dualMasks { + return dualMasks{net.CIDRMask(v4, 32), net.CIDRMask(v6, 128)} +} + +func TestIPMatchHelper(t *testing.T) { + cases := []struct { + ip net.IP + tomatch net.IP + masks dualMasks + ok bool + }{ + {ip1111, ip1110, mkDM(24, -1), true}, + {ip1111, ip1111, mkDM(-1, -1), true}, + {ip1111, ip1110, mkDM(-1, -1), false}, + {ip1111, ip1110, mkDM(32, -1), false}, + {ip1111, ip1110, mkDM(99, -1), false}, + + {ip6666, ip6660, mkDM(-1, 100), true}, + {ip6666, ip6666, mkDM(-1, -1), true}, + {ip6666, ip6660, mkDM(-1, -1), false}, + {ip6666, ip6660, mkDM(-1, 128), false}, + {ip6666, ip6660, mkDM(-1, 200), false}, + } + for _, c := range cases { + ok := ipMatch(c.ip, c.tomatch, c.masks) + if ok != c.ok { + t.Errorf("[%s %s/%v]: expected %v, got %v", + c.ip, c.tomatch, c.masks, c.ok, ok) + } + } +} + +func TestInvalidMacro(t *testing.T) { + // Test that the macro expansion detects some invalid macros. + macros := []string{ + "%{x}", "%{z}", "%{c}", "%{r}", "%{t}", + } + for _, macro := range macros { + r := resolution{ + ip: ip1111, + count: 0, + sender: "sender.com", + trace: t.Logf, + } + + out, err := r.expandMacros(macro, "sender.com") + if out != "" || err != ErrInvalidMacro { + t.Errorf(`[%s]:expected ""/%v, got %q/%v`, + macro, ErrInvalidMacro, out, err) + } + } +} + +// Test that the null tracer doesn't cause unexpected issues, since all the +// other tests override it. +func TestNullTrace(t *testing.T) { + dns := NewDefaultResolver() + defaultTrace = nullTrace + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 +all"} + + // Do a normal resolution, check it passes. + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1") + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } +} + +func TestOverrideLookupLimit(t *testing.T) { + dns := NewDefaultResolver() + defaultTrace = t.Logf + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 include:domain3"} + dns.Txt["domain3"] = []string{"v=spf1 include:domain4"} + dns.Txt["domain4"] = []string{"v=spf1 +all"} + + // The default of 10 should be enough. + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1") + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + // Set the limit to 4, which is enough. + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + OverrideLookupLimit(4)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + // Set the limit to 3, which is not enough. + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + OverrideLookupLimit(3)) + if res != PermError || err != ErrLookupLimitReached { + t.Errorf("expected permerror/lookup limit reached, got %q / %q", + res, err) + } +} + +func TestWithContext(t *testing.T) { + dns := NewDefaultResolver() + defaultTrace = t.Logf + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 +all"} + + // With a normal context. + ctx := context.Background() + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithContext(ctx)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + // With a cancelled context. + ctx, cancelF := context.WithCancel(context.Background()) + cancelF() + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithContext(ctx)) + if res != None || err != context.Canceled { + t.Errorf("expected none/context cancelled, got %q / %q", res, err) + } +} + +func TestWithResolver(t *testing.T) { + // Use a custom resolver, making sure it's different from the default. + defaultResolver = dnstest.NewResolver() + dns := dnstest.NewResolver() + defaultTrace = t.Logf + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 +all"} + + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } +} + +// Test some corner cases when resolver.LookupIPAddr returns an invalid +// address. This can happen if using a buggy custom resolver. +func TestBadResolverResponse(t *testing.T) { + dns := dnstest.NewResolver() + defaultTrace = t.Logf + + // When LookupIPAddr returns an invalid ip, for an "a" field. + dns.Ip["domain1"] = []net.IP{nil} + dns.Txt["domain1"] = []string{"v=spf1 a:domain1 -all"} + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // Same as above, except the field has a mask. + dns.Ip["domain1"] = []net.IP{nil} + dns.Txt["domain1"] = []string{"v=spf1 a:domain1//24 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // When LookupIPAddr returns an invalid ip, for an "mx" field. + dns.Ip["mx.domain1"] = []net.IP{nil} + dns.Mx["domain1"] = []*net.MX{mx("mx.domain1", 5)} + dns.Txt["domain1"] = []string{"v=spf1 mx:domain1 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } + + // Same as above, except the field has a mask. + dns.Ip["mx.domain1"] = []net.IP{nil} + dns.Mx["domain1"] = []*net.MX{mx("mx.domain1", 5)} + dns.Txt["domain1"] = []string{"v=spf1 mx:domain1//24 -all"} + res, err = CheckHostWithSender(ip1111, "helo", "user@domain1", + WithResolver(dns)) + if res != Fail { + t.Errorf("expected fail, got %q / %q", res, err) + } +} + +func TestWithTraceFunc(t *testing.T) { + calls := 0 + var trace TraceFunc = func(f string, a ...interface{}) { + calls++ + t.Logf("tracing "+f, a...) + } + + dns := NewDefaultResolver() + + dns.Txt["domain1"] = []string{"v=spf1 include:domain2"} + dns.Txt["domain2"] = []string{"v=spf1 +all"} + + // Do a normal resolution, check it passes. + res, err := CheckHostWithSender(ip1111, "helo", "user@domain1", + WithTraceFunc(trace)) + if res != Pass { + t.Errorf("expected pass, got %q / %q", res, err) + } + + if calls == 0 { + t.Errorf("expected >0 trace function calls, got 0") + } +} diff -Nru golang-blitiri-go-spf-1.1.0/testdata/blitirispf-tests.yml golang-blitiri-go-spf-1.3.0/testdata/blitirispf-tests.yml --- golang-blitiri-go-spf-1.1.0/testdata/blitirispf-tests.yml 1970-01-01 00:00:00.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/testdata/blitirispf-tests.yml 2021-11-20 17:24:26.000000000 +0000 @@ -0,0 +1,132 @@ +# Simple tests, used for debugging the testing infrastructure. + +--- +description: Simple successes +tests: + test1: + description: Straightforward sucesss + helo: example.net + mailfrom: "foobar@example.net" + host: 1.2.3.4 + result: pass + test2: + description: HELO is set, but expected to be ignored + helo: blargh + mailfrom: "foobar@example.net" + host: 1.2.3.4 + result: pass +zonedata: + example.net: + - SPF: v=spf1 +all +--- +description: Simple failures +tests: + test1: + description: Straightforward failure + helo: example.net + mailfrom: "foobar@example.net" + host: 1.2.3.4 + result: fail + test2: + description: HELO is set, but expected to be ignored + helo: blargh + mailfrom: "foobar@example.net" + host: 1.2.3.4 + result: fail +zonedata: + example.net: + - SPF: v=spf1 -all +--- +description: Regexp edge cases for "a", "mx" and "ptr" +tests: + ipv6-with-a: + description: | + Send from an ip6 address that has "a:" inside. If we incorrectly parse + the "ip6" as "a", this results in a permerror since the host doesn't + match. + mailfrom: "foobar@a1.net" + host: a::a + result: pass + bad-a-mask: + description: | + If we incorrectly parse the "ip6" as "a", this results in a permerror + due to an invalid mask. + mailfrom: "foobar@a2.net" + host: 2001:db8:ff0:100::2 + result: softfail + exp-contains-mx: + description: exp= contains mx:, which should be ignored. + mailfrom: "foobar@expmx.net" + host: 1.2.3.4 + result: softfail + exp-contains-ptr: + description: | + exp= contains ptr:, which should be ignored. + Note this test case involves unusual/invalid domains. + mailfrom: "foobar@expptr.net" + host: 1.2.3.4 + result: softfail +zonedata: + a1.net: + - SPF: v=spf1 ip6:a::a ~all + a2.net: + - SPF: v=spf1 ip6:1a0a:cccc::/29 ~all + expmx.net: + - SPF: v=spf1 exp=mx:mymx.com ~all + - MX: [10, mymx.com] + mymx.com: + - A: 1.2.3.4 + expptr.net: + - SPF: v=spf1 exp=ptr:lalala.com ~all + 4.3.2.1.in-addr.arpa: + - PTR: ptr:lalala.com. + ptr:lalala.com: + - A: 1.2.3.4 +--- +description: Error on PTR forward resolution +tests: + broken-ptr-forward: + description: | + Check that if during 'ptr' forward resolution we get an error, we skip + the domain (and consequently fail the check). + mailfrom: "foo@domain.net" + host: 1.2.3.4 + result: softfail +zonedata: + domain.net: + - SPF: v=spf1 ptr:lalala.com ~all + 4.3.2.1.in-addr.arpa: + - PTR: lalala.com + lalala.com: + - TIMEOUT: true +--- +description: Permanent error on 'exists' resolution +tests: + exists-perm-error: + description: | + Check that if, during an 'exists' forward resolution we get an error, we + fail the check. + mailfrom: "foo@domain.net" + host: 1.2.3.4 + result: softfail +zonedata: + domain.net: + - SPF: v=spf1 exists:lalala.com ~all + lalala.com: + - SERVFAIL: true +--- +description: Resolve H macros correctly +tests: + resolve-h-macros: + description: | + Check that '%{h}' macros are correctly resolved to the HELO/EHLO and not + the sender domain. + mailfrom: "foo@domain.net" + helo: holahola + host: 1.2.3.4 + result: pass +zonedata: + domain.net: + - SPF: v=spf1 exists:%{h}.com ~all + holahola.com: + - A: 127.0.0.2 diff -Nru golang-blitiri-go-spf-1.1.0/testdata/simple-tests.yml golang-blitiri-go-spf-1.3.0/testdata/simple-tests.yml --- golang-blitiri-go-spf-1.1.0/testdata/simple-tests.yml 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/testdata/simple-tests.yml 1970-01-01 00:00:00.000000000 +0000 @@ -1,40 +0,0 @@ -# Simple tests, used for debugging the testing infrastructure. - ---- -description: Simple successes -tests: - test1: - description: Straightforward sucesss - helo: example.net - mailfrom: "foobar@example.net" - host: 1.2.3.4 - result: pass - test2: - description: HELO is set, but expected to be ignored - helo: blargh - mailfrom: "foobar@example.net" - host: 1.2.3.4 - result: pass -zonedata: - example.net: - - SPF: v=spf1 +all ---- -description: Simple failures -tests: - test1: - description: Straightforward failure - helo: example.net - mailfrom: "foobar@example.net" - host: 1.2.3.4 - result: fail - test2: - description: HELO is set, but expected to be ignored - helo: blargh - mailfrom: "foobar@example.net" - host: 1.2.3.4 - result: fail -zonedata: - example.net: - - SPF: v=spf1 -all - - diff -Nru golang-blitiri-go-spf-1.1.0/.travis.yml golang-blitiri-go-spf-1.3.0/.travis.yml --- golang-blitiri-go-spf-1.1.0/.travis.yml 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/.travis.yml 1970-01-01 00:00:00.000000000 +0000 @@ -1,20 +0,0 @@ -# Configuration for https://travis-ci.org/ - -language: go -dist: bionic - -go_import_path: blitiri.com.ar/go/spf - -go: - - 1.7 - - stable - - master - -before_install: - - go get github.com/mattn/goveralls - -script: - - go test ./... - - go test -race ./... - - go test -v -covermode=count -coverprofile=coverage.out - - $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci -repotoken $COVERALLS_TOKEN diff -Nru golang-blitiri-go-spf-1.1.0/yml_test.go golang-blitiri-go-spf-1.3.0/yml_test.go --- golang-blitiri-go-spf-1.1.0/yml_test.go 2020-05-22 21:32:25.000000000 +0000 +++ golang-blitiri-go-spf-1.3.0/yml_test.go 2021-11-20 17:24:26.000000000 +0000 @@ -43,14 +43,15 @@ // Only one of these will be set. type Record struct { - A stringSlice `yaml:"A"` - AAAA stringSlice `yaml:"AAAA"` - MX *MX `yaml:"MX"` - SPF stringSlice `yaml:"SPF"` - TXT stringSlice `yaml:"TXT"` - PTR stringSlice `yaml:"PTR"` - CNAME stringSlice `yaml:"CNAME"` - TIMEOUT bool `yaml:"TIMEOUT"` + A stringSlice `yaml:"A"` + AAAA stringSlice `yaml:"AAAA"` + MX *MX `yaml:"MX"` + SPF stringSlice `yaml:"SPF"` + TXT stringSlice `yaml:"TXT"` + PTR stringSlice `yaml:"PTR"` + CNAME stringSlice `yaml:"CNAME"` + TIMEOUT bool `yaml:"TIMEOUT"` + SERVFAIL bool `yaml:"SERVFAIL"` } func (r Record) String() string { @@ -78,7 +79,10 @@ if r.TIMEOUT { return "TIMEOUT" } - return fmt.Sprintf("") + if r.SERVFAIL { + return "SERVFAIL" + } + return "" } // String slice with a custom yaml unmarshaller, because the yaml parser can't @@ -146,13 +150,13 @@ suites = append(suites, s) } - trace = t.Logf + defaultTrace = t.Logf for _, suite := range suites { t.Logf("suite: %v", suite.Description) // Set up zone for the suite based on zonedata. - dns = NewDNS() + dns := NewDefaultResolver() for domain, records := range suite.ZoneData { t.Logf(" domain %v", domain) for _, record := range records { @@ -162,19 +166,27 @@ Err: "test timeout error", IsTimeout: true, } - dns.errors[domain] = err + dns.Errors[domain] = err + } + if record.SERVFAIL { + err := &net.DNSError{ + Err: "test servfail error", + IsTimeout: false, + IsTemporary: false, + } + dns.Errors[domain] = err } for _, s := range record.A { - dns.ip[domain] = append(dns.ip[domain], net.ParseIP(s)) + dns.Ip[domain] = append(dns.Ip[domain], net.ParseIP(s)) } for _, s := range record.AAAA { - dns.ip[domain] = append(dns.ip[domain], net.ParseIP(s)) + dns.Ip[domain] = append(dns.Ip[domain], net.ParseIP(s)) } for _, s := range record.TXT { - dns.txt[domain] = append(dns.txt[domain], s) + dns.Txt[domain] = append(dns.Txt[domain], s) } if record.MX != nil { - dns.mx[domain] = append(dns.mx[domain], + dns.Mx[domain] = append(dns.Mx[domain], mx(record.MX.Host, record.MX.Prio)) } for _, s := range record.PTR { @@ -189,7 +201,7 @@ s += "." } ip := reverseDNS(t, domain).String() - dns.addr[ip] = append(dns.addr[ip], s) + dns.Addr[ip] = append(dns.Addr[ip], s) } // TODO: CNAME } @@ -202,12 +214,12 @@ // only adding records from SPF if there is no TXT already. // We need to do this in a separate step because order of // appearance is not guaranteed. - if len(dns.txt[domain]) == 0 { + if len(dns.Txt[domain]) == 0 { for _, record := range records { if len(record.SPF) > 0 { // The test suite expect a single-line SPF record to be // concatenated without spaces. - dns.txt[domain] = append(dns.txt[domain], + dns.Txt[domain] = append(dns.Txt[domain], strings.Join(record.SPF, "")) } } @@ -286,8 +298,8 @@ return ip } -func TestSimple(t *testing.T) { - testRFC(t, "testdata/simple-tests.yml") +func TestOurs(t *testing.T) { + testRFC(t, "testdata/blitirispf-tests.yml") } func TestRFC4408(t *testing.T) {