package cooldns import ( "errors" "io/ioutil" "log" "net" "net/http" "regexp" "strings" "github.com/miekg/dns" ) const blockTTL uint32 = 300 var nullIPv4 = net.IPv4(0, 0, 0, 0) var nullIPv6 = net.ParseIP("::/0") func loadBlacklist(config []configBlacklist) map[string]bool { list := make([]string, 0) for _, element := range config { raw, err := requestBacklist(element) if err != nil { log.Printf("Failed to load blacklist %s reason: %s", element.URL, err.Error()) continue } domains := parseRawBlacklist(element, *raw) log.Printf("Added %d blocked domains", len(domains)) list = append(list, domains...) } domainMap := make(map[string]bool) for _, e := range list { domainMap[e] = true } return domainMap } func requestBacklist(blacklist configBlacklist) (*string, error) { if blacklist.URL != "" { return getBlacklistFromURL(blacklist.URL) } return nil, errors.New("No blacklist provided") } func getBlacklistFromURL(url string) (*string, error) { // Request list resp, err := http.Get(url) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != 200 { log.Printf("Got %d status code. Continueing anyway.", resp.StatusCode) } body, err := ioutil.ReadAll(resp.Body) bodyString := string(body) log.Printf("Downloaded blacklist %s", url) return &bodyString, err } // parseRawBlacklist parse the raw string depending on the given format func parseRawBlacklist(blacklist configBlacklist, raw string) []string { switch blacklist.Format { case "host": return parseHostFormat(raw) case "line": return parseLineFormat(raw) default: log.Printf("Failed to parse blacklist. Format not supported: %s", blacklist.Format) log.Println("Supported types are: host, line") return make([]string, 0) } } // parseHostFormat parse the string in the format of a hostfile func parseHostFormat(raw string) []string { finalList := make([]string, 0) reg := regexp.MustCompile(`(?mi)^\s*(#*)\s*(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\s+([a-zA-Z0-9\.\- ]+)$`) matches := reg.FindAllStringSubmatch(raw, -1) for _, match := range matches { if match[1] != "#" { finalList = append(finalList, dns.Fqdn(match[3])) } } return finalList } // parseLineFormat one domain per line, ignore comments func parseLineFormat(raw string) []string { list := make([]string, 0) for _, line := range strings.Split(raw, "\n") { if !strings.HasPrefix(line, "#") { list = append(list, line) } } return list } func handleBlockedDomain(w dns.ResponseWriter, r *dns.Msg) { q := r.Question[0] m := new(dns.Msg) m.SetReply(r) if q.Qtype == dns.TypeA { // Respond with 0.0.0.0 m.Answer = append(m.Answer, &dns.A{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: blockTTL, }, A: nullIPv4, }) } else if q.Qtype == dns.TypeAAAA { // Respond with ::/0 m.Answer = append(m.Answer, &dns.AAAA{ Hdr: dns.RR_Header{ Name: q.Name, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: blockTTL, }, AAAA: nullIPv6, }) } w.WriteMsg(m) }