cool-dns/internal/blacklist.go

142 lines
3.0 KiB
Go
Raw Permalink Normal View History

2021-01-31 21:31:08 +00:00
package cooldns
2020-12-27 21:13:13 +00:00
import (
2020-12-29 21:34:30 +00:00
"errors"
2020-12-27 21:13:13 +00:00
"io/ioutil"
"log"
2020-12-27 23:36:54 +00:00
"net"
2020-12-27 21:13:13 +00:00
"net/http"
"regexp"
2020-12-30 13:14:52 +00:00
"strings"
2020-12-27 21:13:13 +00:00
"github.com/miekg/dns"
)
2021-02-25 12:25:19 +00:00
// blockTTL is the TTL, in seconds, put on the records answered for blocked
// domains: 7 days * 24 h * 60 min * 60 s = 604800.
const blockTTL uint32 = 7 * 24 * 60 * 60
2020-12-27 23:36:54 +00:00
// nullIPv4 and nullIPv6 are the unspecified ("null") addresses returned in
// A and AAAA answers for blocked domains.
//
// nullIPv6 must be parsed from "::" — "::/0" is CIDR notation, which
// net.ParseIP rejects by returning nil, leaving the AAAA answer without an
// address.
var (
	nullIPv4 = net.IPv4(0, 0, 0, 0)
	nullIPv6 = net.ParseIP("::")
)
2020-12-27 21:13:13 +00:00
func loadBlacklist(config []configBlacklist) map[string]bool {
list := make([]string, 0)
for _, element := range config {
raw, err := requestBacklist(element)
if err != nil {
log.Printf("Failed to load blacklist %s reason: %s", element.URL, err.Error())
continue
}
domains := parseRawBlacklist(element, *raw)
2020-12-29 21:34:30 +00:00
log.Printf("Added %d blocked domains", len(domains))
2020-12-27 21:13:13 +00:00
list = append(list, domains...)
}
domainMap := make(map[string]bool)
for _, e := range list {
domainMap[e] = true
}
return domainMap
}
func requestBacklist(blacklist configBlacklist) (*string, error) {
2020-12-29 21:34:30 +00:00
if blacklist.URL != "" {
return getBlacklistFromURL(blacklist.URL)
}
return nil, errors.New("No blacklist provided")
}
// getBlacklistFromURL downloads the resource at url and returns its body as a
// string.
//
// A non-200 status code is only logged; the body is still read and returned.
// If the request or reading the body fails, the error is returned together
// with a nil pointer, so callers never receive a partially downloaded list.
func getBlacklistFromURL(url string) (*string, error) {
	// Request list
	resp, err := http.Get(url)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		log.Printf("Got %d status code. Continueing anyway.", resp.StatusCode)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		// Do not hand back a truncated body or claim a successful download.
		return nil, err
	}

	log.Printf("Downloaded blacklist %s", url)
	bodyString := string(body)
	return &bodyString, nil
}
2021-01-06 14:53:58 +00:00
// parseRawBlacklist parse the raw string depending on the given format
2020-12-27 21:13:13 +00:00
func parseRawBlacklist(blacklist configBlacklist, raw string) []string {
2020-12-29 21:34:30 +00:00
switch blacklist.Format {
case "host":
return parseHostFormat(raw)
2020-12-30 13:14:52 +00:00
case "line":
return parseLineFormat(raw)
2020-12-29 21:34:30 +00:00
default:
log.Printf("Failed to parse blacklist. Format not supported: %s", blacklist.Format)
2020-12-30 13:14:52 +00:00
log.Println("Supported types are: host, line")
2020-12-29 21:34:30 +00:00
return make([]string, 0)
}
}
2021-01-06 14:53:58 +00:00
// parseHostFormat parse the string in the format of a hostfile
2020-12-29 21:34:30 +00:00
func parseHostFormat(raw string) []string {
2020-12-27 21:13:13 +00:00
finalList := make([]string, 0)
2021-02-25 12:24:57 +00:00
reg := regexp.MustCompile(`(?m)^\s*(0\.0\.0\.0) ([a-zA-Z0-9-.]*)`)
2020-12-27 21:13:13 +00:00
matches := reg.FindAllStringSubmatch(raw, -1)
for _, match := range matches {
2021-02-25 12:24:57 +00:00
finalList = append(finalList, dns.Fqdn(match[2]))
2020-12-27 21:13:13 +00:00
}
return finalList
}
2020-12-27 23:36:54 +00:00
2021-01-06 14:53:58 +00:00
// parseLineFormat parses content that lists one domain per line.
//
// Everything from a '#' to the end of the line is treated as a comment.
// Surrounding whitespace (including the '\r' of CRLF line endings) is
// stripped, and lines that end up empty are skipped. The result is never nil.
//
// NOTE(review): parseHostFormat normalizes its entries with dns.Fqdn, while
// the entries returned here are kept as written — confirm which form the
// blacklist lookup expects.
func parseLineFormat(raw string) []string {
	list := make([]string, 0)
	for _, line := range strings.Split(raw, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		list = append(list, line)
	}
	return list
}
2020-12-27 23:36:54 +00:00
func handleBlockedDomain(w dns.ResponseWriter, r *dns.Msg) {
q := r.Question[0]
m := new(dns.Msg)
m.SetReply(r)
if q.Qtype == dns.TypeA {
2021-01-06 14:53:58 +00:00
// Respond with 0.0.0.0
2020-12-27 23:36:54 +00:00
m.Answer = append(m.Answer, &dns.A{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: blockTTL,
},
A: nullIPv4,
})
} else if q.Qtype == dns.TypeAAAA {
2021-01-06 14:53:58 +00:00
// Respond with ::/0
2020-12-27 23:36:54 +00:00
m.Answer = append(m.Answer, &dns.AAAA{
Hdr: dns.RR_Header{
Name: q.Name,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: blockTTL,
},
AAAA: nullIPv6,
})
}
w.WriteMsg(m)
}