package main import ( "flag" "io/ioutil" "log" "net" "os" "os/signal" "strings" "syscall" "github.com/miekg/dns" "gopkg.in/yaml.v3" ) type zoneView struct { rr rrMap acl []string } type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR // config format of the config file type config struct { Zones []configZone `yaml:"zones"` ACL []configACL `yaml:"acl"` Forward configForward `yaml:"forward"` Address string `yaml:"address"` Blacklist []configBlacklist `yaml:"blacklist"` TLS configTLS `yaml:"tls"` Lego configLego `yaml:"lego"` } type configForward struct { ACL []string `yaml:"acl"` Server string `yaml:"server"` } type configACL struct { Name string `yaml:"name"` CIDR string `yaml:"cidr"` } type configZone struct { Zone string `yaml:"zone"` File string `yaml:"file"` ACL []string `yaml:"acl"` } type configBlacklist struct { URL string `yaml:"url"` Format string `yaml:"format"` } type configTLS struct { Enable bool `yaml:"enable"` Address string `yaml:"address"` Cert string `yaml:"cert"` Key string `yaml:"key"` } // All record types to send when a ANY request is send var anyRecordTypes = []uint16{ dns.TypeSOA, dns.TypeA, dns.TypeAAAA, dns.TypeNS, dns.TypeCNAME, dns.TypeMX, dns.TypeTXT, dns.TypeSRV, dns.TypeCAA, } func loadConfig(configPath string) (*config, error) { file, err := ioutil.ReadFile(configPath) if err != nil { return nil, err } var loadedConfig config err = yaml.Unmarshal(file, &loadedConfig) if err != nil { return nil, err } return &loadedConfig, nil } func loadZones(configZones []configZone) (zoneMap, error) { zones := make(zoneMap) for _, z := range configZones { rrs, err := loadZonefile(z.File, z.Zone) if err != nil { return nil, err } if zones[z.Zone] == nil { zones[z.Zone] = make([]zoneView, 0) } zones[z.Zone] = append(zones[z.Zone], zoneView{ rr: createRRMap(rrs), acl: z.ACL, }) log.Printf("Loaded zone %s\n", z.Zone) } return zones, nil } // createRRMap order the rr into a structure that is more easy to use func createRRMap(rrs []dns.RR) rrMap { rrMap := make(rrMap) for _, rr := range rrs { if rrMap[rr.Header().Rrtype] == nil { rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR) } if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil { rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0) } rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr) } return rrMap } func loadZonefile(filepath, origin string) ([]dns.RR, error) { file, err := os.Open(filepath) if err != nil { return nil, err } parser := dns.NewZoneParser(file, origin, filepath) var rrs = make([]dns.RR, 0) for rr, ok := parser.Next(); ok; rr, ok = parser.Next() { rrs = append(rrs, rr) } if err := parser.Err(); err != nil { log.Println(err) } return rrs, nil } // createACLList create a map with the CIDR and the name of the rule func createACLList(config []configACL) (map[string]*net.IPNet, error) { acls := make(map[string]*net.IPNet) for _, aclRule := range config { _, mask, err := net.ParseCIDR(aclRule.CIDR) if err != nil { return nil, err } acls[aclRule.Name] = mask } return acls, nil } // createServer creates a new serve mux. Adds all the logic to handle the request func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux { srv := dns.NewServeMux() c := new(dns.Client) // For all zones set from the config for zoneName, zones := range zones { srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) { // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) if err != nil && ip != nil { log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error()) return } // Check if it is a ACME DNS-01 challange if handleACMERequest(w, r, acmeList) { return } // find out what view to handle the request zoneIndex := -1 for i, zone := range zones { if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) { zoneIndex = i } } // No view found that can handle the request if zoneIndex == -1 { rcodeRequest(w, r, dns.RcodeRefused) return } handleRequest(w, r, zones[zoneIndex]) }) } // Handle any other request for forwarding srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) if err != nil && ip != nil { log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error()) return } // Check if it is a ACME DNS-01 challange if handleACMERequest(w, r, acmeList) { return } // Check ACL rules if !checkACL(config.Forward.ACL, aclList, ip) { rcodeRequest(w, r, dns.RcodeRefused) return } // Check if the domain is bocked if _, ok := blacklist[r.Question[0].Name]; ok { handleBlockedDomain(w, r) } else { // Forward request in, _, err := c.Exchange(r, config.Forward.Server) if err != nil { rcodeRequest(w, r, dns.RcodeServerFailure) return } w.WriteMsg(in) } }) return srv } func listenAndServer(server *dns.ServeMux, address string) { // Start UDP listner go func() { if err := dns.ListenAndServe(address, "udp", server); err != nil { log.Fatalf("Failed to set udp listener %s\n", err.Error()) } }() // Start TCP listner go func() { if err := dns.ListenAndServe(address, "tcp", server); err != nil { log.Fatalf("Failed to set tcp listener %s\n", err.Error()) } }() } func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) { // Start TLS listner go func() { if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil { log.Fatalf("Failed to set DoT listener %s", err.Error()) } }() } func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool { if len(alcRules) != 0 { passed := false for _, rule := range alcRules { if aclList[rule].Contains(ip) { passed = true } } return passed } return true } // rcodeRequest respond to a request with a response code func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { m := new(dns.Msg) m.SetReply(r) m.SetRcode(r, rcode) w.WriteMsg(m) } // handleRequest find the right RR(s) in the view and send them back func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { m := new(dns.Msg) m.SetReply(r) m.Authoritative = true // Only support one question per query because all the other server also does that if len(r.Question) != 1 { rcodeRequest(w, r, dns.RcodeServerFailure) } q := r.Question[0] rrs := zone.rr[q.Qtype] // Handle ANY if q.Qtype == dns.TypeANY { for _, rrType := range anyRecordTypes { m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) } } else { // Handle any other type m.Answer = append(m.Answer, rrs[q.Name]...) // Check for wildcard if len(m.Answer) == 0 { parts := dns.SplitDomainName(q.Name)[1:] searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) foundDomain := rrs[searchDomain] for _, rr := range foundDomain { newRR := rr newRR.Header().Name = q.Name m.Answer = append(m.Answer, newRR) } } } // Handle extras switch q.Qtype { // Dont handle extra stuff when answering ANY request // case dns.TypeANY: // fallthrough case dns.TypeMX: // Resolve MX domains for _, mxRR := range m.Answer { if t, ok := mxRR.(*dns.MX); ok { m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) } } case dns.TypeA, dns.TypeAAAA: if len(m.Answer) == 0 { // no A or AAAA found. Look for CNAME m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) if len(m.Answer) != 0 { // Resolve CNAME for _, nameRR := range m.Answer { if t, ok := nameRR.(*dns.CNAME); ok { m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) } } } } case dns.TypeNS: // Resove NS records for _, nsRR := range m.Answer { if t, ok := nsRR.(*dns.NS); ok { m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) } } } if len(m.Answer) == 0 { m.SetRcode(m, dns.RcodeNameError) } w.WriteMsg(m) } func main() { configPath := flag.String("c", "/etc/cool-dns/config.yaml", "path to the config file") flag.Parse() config, err := loadConfig(*configPath) if err != nil { log.Fatalf("Failed to load config: %s\n", err.Error()) } zones, err := loadZones(config.Zones) if err != nil { log.Fatalf("Failed to load zones: %s\n", err.Error()) } aclList, err := createACLList(config.ACL) if err != nil { log.Fatalf("Failed to parse ACL rules: %s\n", err.Error()) } blacklist := loadBlacklist(config.Blacklist) var acmeMap *legoMap if config.Lego.Enable { acmeMap = startLEGOWebSever(config.Lego) } server := createServer(zones, *config, aclList, blacklist, acmeMap) listenAndServer(server, config.Address) if config.TLS.Enable { listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key) log.Printf("Start listening on tcp %s for tls", config.TLS.Address) } log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address) sig := make(chan os.Signal) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) s := <-sig log.Printf("Signal (%v) received, stopping\n", s) os.Exit(0) }