package main

import (
	"errors"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"regexp"
	"strings"
	"time"

	"github.com/miekg/dns"
)

// blockTTL is the TTL, in seconds, attached to synthesized answers for
// blocked domains.
const blockTTL uint32 = 300

var (
	// nullIPv4 is the address returned in A answers for blocked domains.
	nullIPv4 = net.IPv4(0, 0, 0, 0)

	// nullIPv6 is the address returned in AAAA answers for blocked domains.
	// It must be a real 16-byte address: net.ParseIP("::/0") is CIDR syntax
	// and returns nil, which produced AAAA records with empty RDATA.
	nullIPv6 = net.IPv6zero
)

// blacklistHTTPClient downloads blacklists. Unlike http.DefaultClient it has
// a timeout, so an unresponsive list server cannot stall loading forever.
var blacklistHTTPClient = &http.Client{Timeout: 60 * time.Second}

// hostsIPv4 matches the address column of a hosts-format line. It is compiled
// once at package scope rather than on every parse.
var hostsIPv4 = regexp.MustCompile(`^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$`)

// loadBlacklist downloads and parses every configured blacklist and returns
// the union of their domains as a set (every present key maps to true).
// Lists that fail to download are logged and skipped.
func loadBlacklist(config []configBlacklist) map[string]bool {
	domainMap := make(map[string]bool)
	for _, element := range config {
		raw, err := requestBacklist(element)
		if err != nil {
			log.Printf("Failed to load blacklist %s reason: %s", element.URL, err.Error())
			continue
		}
		domains := parseRawBlacklist(element, *raw)
		log.Printf("Added %d blocked domains", len(domains))
		// Insert directly; the map itself removes duplicates across lists.
		for _, d := range domains {
			domainMap[d] = true
		}
	}
	return domainMap
}

// removeDuplicates returns the distinct values of elements in order of first
// appearance. The result is never nil.
func removeDuplicates(elements []string) []string {
	seen := make(map[string]struct{}, len(elements))
	result := make([]string, 0, len(elements))
	for _, e := range elements {
		if _, dup := seen[e]; dup {
			continue
		}
		seen[e] = struct{}{}
		result = append(result, e)
	}
	return result
}

// requestBacklist fetches the raw contents of a configured blacklist.
// Only URL sources are supported; an entry without a URL is an error.
func requestBacklist(blacklist configBlacklist) (*string, error) {
	if blacklist.URL != "" {
		return getBlacklistFromURL(blacklist.URL)
	}
	return nil, errors.New("No blacklist provided")
}

// getBlacklistFromURL downloads url and returns its body as a string.
// A non-200 status is logged but the body is still returned. Any transport
// or read error is returned with a nil string.
func getBlacklistFromURL(url string) (*string, error) {
	resp, err := blacklistHTTPClient.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("Got %d status code. Continuing anyway.", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		// Don't hand back (or announce) a truncated list.
		return nil, err
	}
	bodyString := string(body)
	log.Printf("Downloaded blacklist %s", url)
	return &bodyString, nil
}

// parseRawBlacklist dispatches raw to the parser for blacklist.Format.
// Unsupported formats are logged and yield an empty (non-nil) slice.
func parseRawBlacklist(blacklist configBlacklist, raw string) []string {
	switch blacklist.Format {
	case "host":
		return parseHostFormat(raw)
	default:
		log.Printf("Failed to parse blacklist. Format not supported: %s", blacklist.Format)
		log.Println("Supported types are: host")
		return make([]string, 0)
	}
}

// parseHostFormat extracts hostnames from hosts-file formatted text.
//
// Each line of the form "<IPv4> <name> [<name>...]" contributes every name,
// lower-cased (DNS names are case-insensitive) and converted to a
// fully-qualified name with a trailing dot. Anything after '#' is a comment.
// Lines whose first field is not a dotted IPv4 address are ignored.
// Windows (CRLF) line endings are handled because strings.Fields treats '\r'
// as whitespace.
func parseHostFormat(raw string) []string {
	domains := make([]string, 0)
	for _, line := range strings.Split(raw, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) < 2 || !hostsIPv4.MatchString(fields[0]) {
			continue
		}
		for _, name := range fields[1:] {
			domains = append(domains, dns.Fqdn(strings.ToLower(name)))
		}
	}
	return domains
}

// handleBlockedDomain answers a query for a blocked domain. A queries get
// nullIPv4 and AAAA queries get nullIPv6, both with blockTTL. Any other type
// gets an empty NOERROR reply. A message without a question is answered with
// FORMERR instead of panicking.
func handleBlockedDomain(w dns.ResponseWriter, r *dns.Msg) {
	m := new(dns.Msg)

	if len(r.Question) == 0 {
		m.SetRcodeFormatError(r)
		if err := w.WriteMsg(m); err != nil {
			log.Printf("Failed to write FORMERR response: %s", err)
		}
		return
	}

	m.SetReply(r)
	q := r.Question[0]
	hdr := dns.RR_Header{
		Name:   q.Name,
		Rrtype: q.Qtype,
		Class:  dns.ClassINET,
		Ttl:    blockTTL,
	}

	switch q.Qtype {
	case dns.TypeA:
		m.Answer = append(m.Answer, &dns.A{Hdr: hdr, A: nullIPv4})
	case dns.TypeAAAA:
		m.Answer = append(m.Answer, &dns.AAAA{Hdr: hdr, AAAA: nullIPv6})
	}

	if err := w.WriteMsg(m); err != nil {
		log.Printf("Failed to write blocked response for %s: %s", q.Name, err)
	}
}