diff --git a/blacklist.go b/blacklist.go new file mode 100644 index 0000000..ab99104 --- /dev/null +++ b/blacklist.go @@ -0,0 +1,79 @@ +package main + +import ( + "io/ioutil" + "log" + "net/http" + "regexp" + + "github.com/miekg/dns" +) + +func loadBlacklist(config []configBlacklist) map[string]bool { + list := make([]string, 0) + for _, element := range config { + raw, err := requestBacklist(element) + + if err != nil { + log.Printf("Failed to load blacklist %s reason: %s", element.URL, err.Error()) + continue + } + + domains := parseRawBlacklist(element, *raw) + list = append(list, domains...) + } + + // list = removeDuplicates(list) + // sort.Strings(list) + + domainMap := make(map[string]bool) + for _, e := range list { + domainMap[e] = true + } + + return domainMap +} + +func removeDuplicates(elements []string) []string { + encountered := map[string]bool{} + result := []string{} + + for v := range elements { + if !encountered[elements[v]] { + encountered[elements[v]] = true + + result = append(result, elements[v]) + } + } + + return result +} + +func requestBacklist(blacklist configBlacklist) (*string, error) { + // Request list + resp, err := http.Get(blacklist.URL) + if err != nil { + return nil, err + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + bodyString := string(body) + + log.Printf("Downloaded blacklist %s", blacklist.URL) + + return &bodyString, err +} + +func parseRawBlacklist(blacklist configBlacklist, raw string) []string { + finalList := make([]string, 0) + reg := regexp.MustCompile(`(?mi)^\s*(#*)\s*(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\s+([a-zA-Z0-9\.\- ]+)$`) + matches := reg.FindAllStringSubmatch(raw, -1) + for _, match := range matches { + if match[1] != "#" { + finalList = append(finalList, dns.Fqdn(match[3])) + } + } + + return finalList +} diff --git a/config.yml b/config.yml index 7ea1414..144b94b 100644 --- a/config.yml +++ b/config.yml @@ -17,7 +17,11 @@ acl: forward: acl: - - vpn + - local server: "8.8.8.8:53" -address: 0.0.0.0:8053 \ No newline at end of file +address: 0.0.0.0:8053 + +blacklist: + - url: https://raw.githubusercontent.com/anudeepND/blacklist/master/adservers.txt + format: host diff --git a/coolDns.go b/coolDns.go index 657d6e7..9548053 100644 --- a/coolDns.go +++ b/coolDns.go @@ -24,10 +24,11 @@ type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR type config struct { - Zones []configZone `yaml:"zones"` - ACL []configACL `yaml:"acl"` - Forward configForward `yaml:"forward"` - Address string `yaml:"address"` + Zones []configZone `yaml:"zones"` + ACL []configACL `yaml:"acl"` + Forward configForward `yaml:"forward"` + Address string `yaml:"address"` + Blacklist []configBlacklist `yaml:"blacklist"` } type configForward struct { @@ -46,6 +47,11 @@ type configZone struct { ACL []string `yaml:"acl"` } +type configBlacklist struct { + URL string `yaml:"url"` + Format string `yaml:"format"` +} + var anyRecordTypes = []uint16{ dns.TypeSOA, dns.TypeA, @@ -149,7 +155,7 @@ func createACLList(config []configACL) (map[string]*net.IPNet, error) { return acls, nil } -func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet) *dns.ServeMux { +func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool) *dns.ServeMux { srv := dns.NewServeMux() c := new(dns.Client) @@ -198,15 +204,45 @@ func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet) * return } - // Forward request - in, _, err := c.Exchange(r, config.Forward.Server) + if _, ok := blacklist[r.Question[0].Name]; ok { + // Domain is blocked + m := new(dns.Msg) + m.SetReply(r) + if r.Question[0].Qtype == dns.TypeA { + m.Answer = append(m.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: r.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 1000, + }, + A: net.IPv4(0, 0, 0, 0), + }) + } else if r.Question[0].Qtype == dns.TypeAAAA { + m.Answer = append(m.Answer, &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: r.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 1000, + }, + AAAA: net.ParseIP("::/0"), + }) + } - if err != nil { - rcodeRequest(w, r, dns.RcodeServerFailure) - return + w.WriteMsg(m) + + } else { + // Forward request + in, _, err := c.Exchange(r, config.Forward.Server) + + if err != nil { + rcodeRequest(w, r, dns.RcodeServerFailure) + return + } + w.WriteMsg(in) } - w.WriteMsg(in) }) return srv @@ -338,7 +374,9 @@ func main() { log.Fatalf("Failed to parse ACL rules: %s\n", err.Error()) } - server := createServer(zones, *config, aclList) + blacklist := loadBlacklist(config.Blacklist) + + server := createServer(zones, *config, aclList, blacklist) listenAndServer(server, config.Address)