cool-dns/internal/cooldns.go
2021-02-02 00:57:24 +01:00

192 lines
4.3 KiB
Go

package cooldns
import (
"log"
"net"
"os"
"path/filepath"
"github.com/miekg/dns"
)
type (
	// rrMap stores resource records behind two lookup levels: a uint16 key
	// and then a string key.
	// NOTE(review): presumably RR type and owner name; confirm against the
	// zone loader.
	rrMap map[uint16]map[string][]dns.RR

	// zoneView is one variant of a zone together with the ACL rule names
	// that select it. A view with no rule names carries no client
	// restriction.
	zoneView struct {
		rr  rrMap
		acl []string
	}

	// zoneMap lists the views available for each zone, keyed by the zone
	// name under which the handler is registered.
	zoneMap map[string][]zoneView
)
// Start boots cooldns from the configuration file at configPath.
//
// It loads the config, zones, ACL networks and blacklist, optionally starts
// the LEGO web server, and then launches the plain listeners plus the TLS
// listener when TLS is enabled. Any failure while loading terminates the
// process through log.Fatalf.
func Start(configPath string) {
	// abortOn ends the process with the given message when err is set.
	abortOn := func(format string, err error) {
		if err != nil {
			log.Fatalf(format, err.Error())
		}
	}

	cfg, err := loadConfig(configPath)
	abortOn("Failed to load config: %s\n", err)

	// Switch to the directory holding the config file before the zones are
	// loaded.
	// NOTE(review): presumably so relative paths in the config resolve
	// against it; confirm.
	abortOn("Failed to goto config dir: %s", os.Chdir(filepath.Dir(configPath)))

	zones, err := loadZones(cfg.Zones)
	abortOn("Failed to load zones: %s\n", err)

	aclList, err := createACLList(cfg.ACL)
	abortOn("Failed to parse ACL rules: %s\n", err)

	blacklist := loadBlacklist(cfg.Blacklist)

	// acmeMap stays nil unless LEGO support is switched on.
	var acmeMap *legoMap
	if cfg.Lego.Enable {
		acmeMap = startLEGOWebSever(cfg.Lego)
	}

	mux := createServer(zones, *cfg, aclList, blacklist, acmeMap)

	// The listener helpers return immediately; serving happens in goroutines.
	listenAndServer(mux, cfg.Address)
	if cfg.TLS.Enable {
		listenAndServerTLS(mux, cfg.TLS.Address, cfg.TLS.Cert, cfg.TLS.Key)
		log.Printf("Start listening on tcp %s for tls", cfg.TLS.Address)
	}
	log.Printf("Start listening on udp %s and tcp %s\n", cfg.Address, cfg.Address)
}
// createServer builds the DNS serve mux and registers every request handler.
//
// One handler is registered per zone name in zones. It picks the view that
// applies to the client and passes the request to handleRequest, or answers
// REFUSED when no view applies. A catch-all handler on "." covers every other
// name: it enforces config.Forward.ACL, intercepts names present in blacklist
// and otherwise relays the query to config.Forward.Server. Both handlers give
// handleACMERequest the first chance to answer.
//
// Parameters:
//   - zones: the views available per zone name.
//   - config: only the forwarding settings are read here.
//   - aclList: networks addressed by the ACL rule names.
//   - blacklist: names that must not be forwarded; the question name is
//     looked up verbatim and only key presence matters.
//   - acmeList: handed through to handleACMERequest unchanged.
//
// Returns the populated mux.
func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux {
	srv := dns.NewServeMux()
	c := new(dns.Client)

	// clientIP resolves the IP address of the requesting client. It reports
	// false, after logging, when no address can be determined.
	clientIP := func(w dns.ResponseWriter) (net.IP, bool) {
		addr := w.RemoteAddr()

		// Prefer the IP carried by the address value itself: a textual
		// round trip fails for scoped IPv6 addresses ("fe80::1%eth0"),
		// which net.ParseIP rejects.
		switch a := addr.(type) {
		case *net.UDPAddr:
			if a != nil && len(a.IP) > 0 {
				return a.IP, true
			}
		case *net.TCPAddr:
			if a != nil && len(a.IP) > 0 {
				return a.IP, true
			}
		}

		host, _, err := net.SplitHostPort(addr.String())
		if err != nil {
			log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
			return nil, false
		}
		ip := net.ParseIP(host)
		if ip == nil {
			log.Printf("Failed to parse remote IP %q\n", host)
			return nil, false
		}
		return ip, true
	}

	// For all zones set from the config
	for zoneName, views := range zones {
		// Before Go 1.22 the range variable is shared by all iterations, so
		// each closure needs its own copy or every handler would serve the
		// views of whichever zone was visited last.
		views := views

		srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
			ip, ok := clientIP(w)
			if !ok {
				// Without a client address no view can be selected; the
				// request is dropped unanswered.
				return
			}

			// Check if it is an ACME DNS-01 challenge
			if handleACMERequest(w, r, acmeList) {
				return
			}

			// Find the view that handles the request. A view without ACL
			// rules is only a fallback and never replaces a view already
			// chosen; among views with rules, the last match wins.
			viewIndex := -1
			for i, view := range views {
				if len(view.acl) == 0 {
					if viewIndex == -1 {
						viewIndex = i
					}
					continue
				}
				if checkACL(view.acl, aclList, ip) {
					viewIndex = i
				}
			}

			// No view found that can handle the request
			if viewIndex == -1 {
				rcodeRequest(w, r, dns.RcodeRefused)
				return
			}
			handleRequest(w, r, views[viewIndex])
		})
	}

	// Handle any other request for forwarding
	srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
		ip, ok := clientIP(w)
		if !ok {
			return
		}

		// Check if it is an ACME DNS-01 challenge
		if handleACMERequest(w, r, acmeList) {
			return
		}

		// Check ACL rules
		if !checkACL(config.Forward.ACL, aclList, ip) {
			rcodeRequest(w, r, dns.RcodeRefused)
			return
		}

		// Check if the domain is blocked.
		// NOTE(review): indexing Question[0] relies on the mux dispatching
		// only requests that carry at least one question; confirm.
		if _, blocked := blacklist[r.Question[0].Name]; blocked {
			handleBlockedDomain(w, r)
			return
		}

		// Forward request
		in, _, err := c.Exchange(r, config.Forward.Server)
		if err != nil {
			rcodeRequest(w, r, dns.RcodeServerFailure)
			return
		}
		if err := w.WriteMsg(in); err != nil {
			log.Printf("Failed to write forwarded response: %s\n", err.Error())
		}
	})

	return srv
}
// listenAndServer launches one UDP and one TCP listener on address, each in
// its own goroutine, and returns without waiting for them. A listener that
// fails terminates the process through log.Fatalf.
func listenAndServer(server *dns.ServeMux, address string) {
	listeners := []struct {
		network string
		failure string
	}{
		{network: "udp", failure: "Failed to set udp listener %s\n"},
		{network: "tcp", failure: "Failed to set tcp listener %s\n"},
	}

	for _, l := range listeners {
		// Hand the values over as arguments so each goroutine works on its
		// own copy rather than on the loop variable.
		go func(network, failure string) {
			err := dns.ListenAndServe(address, network, server)
			if err != nil {
				log.Fatalf(failure, err.Error())
			}
		}(l.network, l.failure)
	}
}
// listenAndServerTLS launches the TLS listener on address in a background
// goroutine, using the certificate and key files at cert and key. It returns
// immediately; a listener failure terminates the process through log.Fatalf.
func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) {
	serve := func() {
		err := dns.ListenAndServeTLS(address, cert, key, server)
		if err == nil {
			return
		}
		log.Fatalf("Failed to set DoT listener %s", err.Error())
	}
	go serve()
}
// checkACL reports whether ip is permitted by the rule names in alcRules.
//
// An empty rule list permits every client. Otherwise ip must lie inside at
// least one of the networks that aclList stores under the given rule names.
// Rule names without an entry in aclList (or with a nil entry) match nothing
// instead of causing a nil-pointer panic.
func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
	if len(alcRules) == 0 {
		return true
	}
	for _, rule := range alcRules {
		network := aclList[rule]
		if network == nil {
			// Unknown rule name: nothing to match against.
			continue
		}
		if network.Contains(ip) {
			// One matching network is enough; no need to scan the rest.
			return true
		}
	}
	return false
}
// rcodeRequest answers request r with a reply that carries the response code
// rcode and writes it to w. A failure to write the reply is logged.
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
	m := new(dns.Msg)
	// SetRcode turns m into a reply to r before setting the code, so no
	// separate SetReply call is needed.
	m.SetRcode(r, rcode)
	if err := w.WriteMsg(m); err != nil {
		log.Printf("Failed to write response: %s\n", err.Error())
	}
}