package cooldns import ( "log" "net" "os" "path/filepath" "github.com/miekg/dns" ) type zoneView struct { rr rrMap acl []string } type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR // Start starts cooldns func Start(configPath string) { config, err := loadConfig(configPath) if err != nil { log.Fatalf("Failed to load config: %s\n", err.Error()) } err = os.Chdir(filepath.Dir(configPath)) if err != nil { log.Fatalf("Failed to goto config dir: %s", err.Error()) } zones, err := loadZones(config.Zones) if err != nil { log.Fatalf("Failed to load zones: %s\n", err.Error()) } aclList, err := createACLList(config.ACL) if err != nil { log.Fatalf("Failed to parse ACL rules: %s\n", err.Error()) } blacklist := loadBlacklist(config.Blacklist) var acmeMap *legoMap if config.Lego.Enable { acmeMap = startLEGOWebSever(config.Lego) } server := createServer(zones, *config, aclList, blacklist, acmeMap) listenAndServer(server, config.Address) if config.TLS.Enable { listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key) log.Printf("Start listening on tcp %s for tls", config.TLS.Address) } log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address) } // createServer creates a new serve mux. Adds all the logic to handle the request func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux { srv := dns.NewServeMux() c := new(dns.Client) // For all zones set from the config for zoneName, zones := range zones { srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) { // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) if err != nil && ip != nil { log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error()) return } // Check if it is a ACME DNS-01 challange if handleACMERequest(w, r, acmeList) { return } // find out what view to handle the request zoneIndex := -1 for i, zone := range zones { if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) { zoneIndex = i } } // No view found that can handle the request if zoneIndex == -1 { rcodeRequest(w, r, dns.RcodeRefused) return } handleRequest(w, r, zones[zoneIndex]) }) } // Handle any other request for forwarding srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) if err != nil && ip != nil { log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error()) return } // Check if it is a ACME DNS-01 challange if handleACMERequest(w, r, acmeList) { return } // Check ACL rules if !checkACL(config.Forward.ACL, aclList, ip) { rcodeRequest(w, r, dns.RcodeRefused) return } // Check if the domain is bocked if _, ok := blacklist[r.Question[0].Name]; ok { handleBlockedDomain(w, r) } else { // Forward request in, _, err := c.Exchange(r, config.Forward.Server) if err != nil { rcodeRequest(w, r, dns.RcodeServerFailure) return } w.WriteMsg(in) } }) return srv } func listenAndServer(server *dns.ServeMux, address string) { // Start UDP listner go func() { if err := dns.ListenAndServe(address, "udp", server); err != nil { log.Fatalf("Failed to set udp listener %s\n", err.Error()) } }() // Start TCP listner go func() { if err := dns.ListenAndServe(address, "tcp", server); err != nil { log.Fatalf("Failed to set tcp listener %s\n", err.Error()) } }() } func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) { // Start TLS listner go func() { if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil { log.Fatalf("Failed to set DoT listener %s", err.Error()) } }() } func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool { if len(alcRules) != 0 { passed := false for _, rule := range alcRules { if aclList[rule].Contains(ip) { passed = true } } return passed } return true } // rcodeRequest respond to a request with a response code func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { m := new(dns.Msg) m.SetReply(r) m.SetRcode(r, rcode) w.WriteMsg(m) }