diff --git a/blacklist.go b/blacklist.go index 9a68ac0..4a3c49d 100644 --- a/blacklist.go +++ b/blacklist.go @@ -69,6 +69,7 @@ func getBlacklistFromURL(url string) (*string, error) { return &bodyString, err } +// parseRawBlacklist parse the raw string depending on the given format func parseRawBlacklist(blacklist configBlacklist, raw string) []string { switch blacklist.Format { case "host": @@ -82,6 +83,7 @@ func parseRawBlacklist(blacklist configBlacklist, raw string) []string { } } +// parseHostFormat parse the string in the format of a hostfile func parseHostFormat(raw string) []string { finalList := make([]string, 0) reg := regexp.MustCompile(`(?mi)^\s*(#*)\s*(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\s+([a-zA-Z0-9\.\- ]+)$`) @@ -95,6 +97,7 @@ func parseHostFormat(raw string) []string { return finalList } +// parseLineFormat one domain per line, ignore comments func parseLineFormat(raw string) []string { list := make([]string, 0) @@ -113,6 +116,7 @@ func handleBlockedDomain(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) m.SetReply(r) if q.Qtype == dns.TypeA { + // Respond with 0.0.0.0 m.Answer = append(m.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: q.Name, @@ -123,6 +127,7 @@ func handleBlockedDomain(w dns.ResponseWriter, r *dns.Msg) { A: nullIPv4, }) } else if q.Qtype == dns.TypeAAAA { + // Respond with ::/0 m.Answer = append(m.Answer, &dns.AAAA{ Hdr: dns.RR_Header{ Name: q.Name, diff --git a/coolDns.go b/coolDns.go index 970d8db..6b1ff0b 100644 --- a/coolDns.go +++ b/coolDns.go @@ -23,6 +23,7 @@ type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR +// config format of the config file type config struct { Zones []configZone `yaml:"zones"` ACL []configACL `yaml:"acl"` @@ -60,6 +61,7 @@ type configTLS struct { Key string `yaml:"key"` } +// All record types to send when a ANY request is send var anyRecordTypes = []uint16{ dns.TypeSOA, dns.TypeA, @@ -109,6 +111,7 @@ func loadZones(configZones []configZone) (zoneMap, error) { return zones, nil } +// createRRMap order the rr into a structure that is more easy to use func createRRMap(rrs []dns.RR) rrMap { rrMap := make(rrMap) for _, rr := range rrs { @@ -147,6 +150,7 @@ func loadZonefile(filepath, origin string) ([]dns.RR, error) { return rrs, nil } +// createACLList create a map with the CIDR and the name of the rule func createACLList(config []configACL) (map[string]*net.IPNet, error) { acls := make(map[string]*net.IPNet) @@ -163,10 +167,12 @@ func createACLList(config []configACL) (map[string]*net.IPNet, error) { return acls, nil } +// createServer creates a new serve mux. Adds all the logic to handle the request func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool) *dns.ServeMux { srv := dns.NewServeMux() c := new(dns.Client) + // For all zones set from the config for zoneName, zones := range zones { srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) { @@ -187,6 +193,7 @@ func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, b } } + // No view found that can handle the request if zoneIndex == -1 { rcodeRequest(w, r, dns.RcodeRefused) return @@ -196,8 +203,10 @@ func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, b }) } - // Handle any other request + // Handle any other request for forwarding srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { + + // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) @@ -212,6 +221,7 @@ func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, b return } + // Check if the domain is bocked if _, ok := blacklist[r.Question[0].Name]; ok { handleBlockedDomain(w, r) } else { @@ -222,21 +232,23 @@ func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, b rcodeRequest(w, r, dns.RcodeServerFailure) return } + w.WriteMsg(in) } - }) return srv } func listenAndServer(server *dns.ServeMux, address string) { + // Start UDP listner go func() { if err := dns.ListenAndServe(address, "udp", server); err != nil { log.Fatalf("Failed to set udp listener %s\n", err.Error()) } }() + // Start TCP listner go func() { if err := dns.ListenAndServe(address, "tcp", server); err != nil { log.Fatalf("Failed to set tcp listener %s\n", err.Error()) @@ -245,6 +257,7 @@ func listenAndServer(server *dns.ServeMux, address string) { } func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) { + // Start TLS listner go func() { if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil { log.Fatalf("Failed to set DoT listener %s", err.Error()) @@ -266,6 +279,7 @@ func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool return true } +// rcodeRequest respond to a request with a response code func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { m := new(dns.Msg) m.SetReply(r) @@ -273,6 +287,7 @@ func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { w.WriteMsg(m) } +// handleRequest find the right RR(s) in the view and send them back func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { m := new(dns.Msg) m.SetReply(r)