diff --git a/go.mod b/go.mod index 83efea7..fd0fde3 100644 --- a/go.mod +++ b/go.mod @@ -6,5 +6,6 @@ require github.com/oschwald/geoip2-golang v1.13.0 require ( github.com/oschwald/maxminddb-golang v1.13.1 // indirect + golang.org/x/sync v0.21.0 // indirect golang.org/x/sys v0.45.0 // indirect ) diff --git a/go.sum b/go.sum index 8a6726a..6698530 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/sync v0.21.0 h1:HLII4xRRTtCRkxYp4HNFF0Js/Og6q2i++KXbg0gHCwM= +golang.org/x/sync v0.21.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.45.0 h1:dO4czNzziLiiXplLQgBCEpCvXQ3dnkn0SdaZSYdQ+FY= golang.org/x/sys v0.45.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/main.go b/main.go index 914534c..49ff935 100644 --- a/main.go +++ b/main.go @@ -22,6 +22,9 @@ var verbose bool var enableIPv6 bool var url string +var resolveDomains bool + + func main() { _ = os.Unsetenv("ALL_PROXY") _ = os.Unsetenv("HTTP_PROXY") @@ -38,6 +41,7 @@ func main() { flag.BoolVar(&enableIPv6, "46", false, "Enable IPv6 in additional to IPv4") flag.StringVar(&url, "url", "", "Crawl the domain list from a URL, "+ "e.g. https://launchpad.net/ubuntu/+archivemirrors") + flag.BoolVar(&resolveDomains, "resolve-domains", false, "DNS-resolve the domain from the TLS certificate and verify if the scanned IP is among the A/AAAA records") flag.Parse() if verbose { slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ diff --git a/scanner.go b/scanner.go index 0b8af0e..9839cf2 100644 --- a/scanner.go +++ b/scanner.go @@ -1,14 +1,24 @@ package main import ( + "context" "crypto/tls" "log/slog" "net" "strconv" "strings" + "sync" "time" + + "golang.org/x/sync/singleflight" ) +// dnsCache stores only successful resolutions. +var dnsCache = make(map[string][]net.IP) +var dnsCacheMu sync.RWMutex +var dnsSF singleflight.Group + + func ScanTLS(host Host, out chan<- string, geo *Geo) { if host.IP == nil { ip, err := LookupIP(host.Origin) @@ -46,7 +56,13 @@ func ScanTLS(host Host, out chan<- string, geo *Geo) { } state := c.ConnectionState() alpn := state.NegotiatedProtocol - domain := state.PeerCertificates[0].Subject.CommonName + domain := "" + if len(state.PeerCertificates[0].DNSNames) > 0 { + domain = state.PeerCertificates[0].DNSNames[0] + } else { + domain = state.PeerCertificates[0].Subject.CommonName + } + issuers := strings.Join(state.PeerCertificates[0].Issuer.Organization, " | ") length := 0 leaf := state.PeerCertificates[0] @@ -57,13 +73,64 @@ func ScanTLS(host Host, out chan<- string, geo *Geo) { } } - log := slog.Info feasible := true geoCode := geo.GetGeo(host.IP) if state.Version != tls.VersionTLS13 || alpn != "h2" || len(domain) == 0 || len(issuers) == 0 { // not feasible - log = slog.Debug feasible = false + } + + if feasible && resolveDomains { + lookupDomain := domain + if strings.HasPrefix(lookupDomain, "*.") { + lookupDomain = lookupDomain[2:] + } + + dnsCacheMu.RLock() + resolvedIPs, cached := dnsCache[lookupDomain] + dnsCacheMu.RUnlock() + + if !cached { + v, err, _ := dnsSF.Do(lookupDomain, func() (interface{}, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + addrs, err := net.DefaultResolver.LookupIPAddr(ctx, lookupDomain) + if err != nil { + return ([]net.IP)(nil), err + } + ips := make([]net.IP, len(addrs)) + for i, a := range addrs { + ips[i] = a.IP + } + dnsCacheMu.Lock() + dnsCache[lookupDomain] = ips + dnsCacheMu.Unlock() + return ips, nil + }) + if err != nil { + slog.Debug("DNS resolution failed", "domain", lookupDomain, "err", err) + resolvedIPs = nil + } else { + resolvedIPs = v.([]net.IP) + } + } + + ipMatched := false + for _, rip := range resolvedIPs { + if rip.Equal(host.IP) { + ipMatched = true + break + } + } + if !ipMatched { + slog.Debug("IP mismatch with cert domain DNS records", "ip", host.IP.String(), "domain", lookupDomain) + feasible = false + } + } + + log := slog.Info + if !feasible { + log = slog.Debug } else { out <- strings.Join([]string{ host.IP.String(),