cool-dns/internal/cooldns.go

190 lines
4.3 KiB
Go
Raw Normal View History

2021-01-31 21:31:08 +00:00
package cooldns
2020-12-21 21:43:07 +00:00
import (
"log"
2020-12-22 18:07:39 +00:00
"net"
2020-12-21 21:43:07 +00:00
"os"
2021-01-08 18:12:58 +00:00
"path/filepath"
2020-12-21 21:43:07 +00:00
"github.com/miekg/dns"
)
2020-12-25 13:43:14 +00:00
type zoneView struct {
rr rrMap
acl []string
2020-12-21 21:43:07 +00:00
}
2020-12-25 13:43:14 +00:00
type zoneMap map[string][]zoneView
2021-01-31 21:31:08 +00:00
// Start starts cooldns
func Start(configPath string) {
config, err := loadConfig(configPath)
if err != nil {
log.Fatalf("Failed to load config: %s\n", err.Error())
}
err = os.Chdir(filepath.Dir(configPath))
if err != nil {
log.Fatalf("Failed to goto config dir: %s", err.Error())
}
zones, err := loadZones(config.Zones)
if err != nil {
log.Fatalf("Failed to load zones: %s\n", err.Error())
}
aclList, err := createACLList(config.ACL)
if err != nil {
log.Fatalf("Failed to parse ACL rules: %s\n", err.Error())
}
blacklist := loadBlacklist(config.Blacklist)
var acmeMap *legoMap
if config.Lego.Enable {
acmeMap = startLEGOWebSever(config.Lego)
}
server := createServer(zones, *config, aclList, blacklist, acmeMap)
listenAndServer(server, config.Address)
if config.TLS.Enable {
listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key)
log.Printf("Start listening on tcp %s for tls", config.TLS.Address)
}
log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address)
}
2021-01-06 14:53:58 +00:00
// createServer creates a new serve mux. Adds all the logic to handle the request
2021-01-08 15:08:57 +00:00
func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux {
2020-12-21 21:43:07 +00:00
srv := dns.NewServeMux()
2020-12-23 20:44:33 +00:00
c := new(dns.Client)
2020-12-21 21:43:07 +00:00
2021-01-06 14:53:58 +00:00
// For all zones set from the config
2020-12-25 13:43:14 +00:00
for zoneName, zones := range zones {
srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
2020-12-22 18:07:39 +00:00
2020-12-23 12:10:58 +00:00
// Parse IP
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
2020-12-23 20:58:08 +00:00
ip := net.ParseIP(remoteIP)
if err != nil && ip != nil {
2020-12-26 20:56:51 +00:00
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
2020-12-23 12:10:58 +00:00
return
}
2021-01-08 15:08:57 +00:00
// Check if it is a ACME DNS-01 challange
2021-02-03 14:00:40 +00:00
if config.Lego.Enable && handleACMERequest(w, r, acmeList) {
2021-01-08 15:08:57 +00:00
return
}
2020-12-25 13:43:14 +00:00
// find out what view to handle the request
zoneIndex := -1
2020-12-22 18:07:39 +00:00
2020-12-25 13:43:14 +00:00
for i, zone := range zones {
if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) {
zoneIndex = i
2020-12-22 21:13:32 +00:00
}
2020-12-25 13:43:14 +00:00
}
2020-12-21 21:43:07 +00:00
2021-01-06 14:53:58 +00:00
// No view found that can handle the request
2020-12-25 13:43:14 +00:00
if zoneIndex == -1 {
rcodeRequest(w, r, dns.RcodeRefused)
return
2020-12-21 21:43:07 +00:00
}
2020-12-25 13:43:14 +00:00
handleRequest(w, r, zones[zoneIndex])
2020-12-21 21:43:07 +00:00
})
}
2021-01-06 14:53:58 +00:00
// Handle any other request for forwarding
2020-12-23 20:44:33 +00:00
srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
2021-01-06 14:53:58 +00:00
// Parse IP
2020-12-23 20:44:33 +00:00
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
ip := net.ParseIP(remoteIP)
if err != nil && ip != nil {
2020-12-26 20:56:51 +00:00
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
2020-12-23 20:44:33 +00:00
return
}
2021-01-08 15:08:57 +00:00
// Check if it is a ACME DNS-01 challange
2021-02-03 14:00:40 +00:00
if config.Lego.Enable && handleACMERequest(w, r, acmeList) {
2021-01-08 15:08:57 +00:00
return
}
2020-12-23 20:44:33 +00:00
// Check ACL rules
2020-12-23 20:58:08 +00:00
if !checkACL(config.Forward.ACL, aclList, ip) {
rcodeRequest(w, r, dns.RcodeRefused)
return
2020-12-23 20:44:33 +00:00
}
2021-01-06 14:53:58 +00:00
// Check if the domain is bocked
2020-12-27 21:13:13 +00:00
if _, ok := blacklist[r.Question[0].Name]; ok {
2020-12-27 23:36:54 +00:00
handleBlockedDomain(w, r)
2020-12-27 21:13:13 +00:00
} else {
// Forward request
in, _, err := c.Exchange(r, config.Forward.Server)
if err != nil {
rcodeRequest(w, r, dns.RcodeServerFailure)
return
}
2021-01-06 14:53:58 +00:00
2020-12-27 21:13:13 +00:00
w.WriteMsg(in)
2020-12-23 20:44:33 +00:00
}
})
2020-12-21 21:43:07 +00:00
return srv
}
2020-12-26 13:41:08 +00:00
func listenAndServer(server *dns.ServeMux, address string) {
2021-01-06 14:53:58 +00:00
// Start UDP listner
2020-12-21 21:43:07 +00:00
go func() {
2020-12-26 13:41:08 +00:00
if err := dns.ListenAndServe(address, "udp", server); err != nil {
2020-12-21 21:43:07 +00:00
log.Fatalf("Failed to set udp listener %s\n", err.Error())
}
}()
2021-01-06 14:53:58 +00:00
// Start TCP listner
2020-12-21 21:43:07 +00:00
go func() {
2020-12-26 13:41:08 +00:00
if err := dns.ListenAndServe(address, "tcp", server); err != nil {
2020-12-21 21:43:07 +00:00
log.Fatalf("Failed to set tcp listener %s\n", err.Error())
}
}()
}
2020-12-30 20:59:33 +00:00
func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) {
2021-01-06 14:53:58 +00:00
// Start TLS listner
2020-12-30 20:59:33 +00:00
go func() {
if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil {
log.Fatalf("Failed to set DoT listener %s", err.Error())
}
}()
}
2020-12-23 20:58:08 +00:00
func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
if len(alcRules) != 0 {
passed := false
for _, rule := range alcRules {
if aclList[rule].Contains(ip) {
passed = true
}
}
return passed
}
return true
}
2021-01-06 14:53:58 +00:00
// rcodeRequest respond to a request with a response code
2020-12-23 20:58:08 +00:00
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, rcode)
w.WriteMsg(m)
}