package main import ( "io/ioutil" "log" "net" "os" "os/signal" "strconv" "strings" "syscall" "github.com/miekg/dns" "gopkg.in/yaml.v3" ) type zone struct { zone string rr rrMap acl []string } type rrMap map[uint16]map[string][]dns.RR type config struct { Zones []configZone `yaml:"zones"` ACL []configACL `yaml:"acl"` Forward configForward `yaml:"forward"` } type configForward struct { ACL []string `yaml:"acl"` } type configACL struct { Name string `yaml:"name"` CIDR string `yaml:"cidr"` } type configZone struct { Zone string `yaml:"zone"` File string `yaml:"file"` ACL []string `yaml:"acl"` } func loadConfig() (*config, error) { file, err := ioutil.ReadFile("config.yml") if err != nil { return nil, err } var loadedConfig config err = yaml.Unmarshal(file, &loadedConfig) if err != nil { return nil, err } return &loadedConfig, nil } func loadZones(configZones []configZone) ([]zone, error) { zones := make([]zone, 0) for _, z := range configZones { rrs, err := loadZonefile(z.File) if err != nil { return nil, err } zones = append(zones, zone{ zone: z.Zone, rr: createRRMap(rrs), acl: z.ACL, }) log.Printf("Loaded zone %s\n", z.Zone) } return zones, nil } func createRRMap(rrs []dns.RR) rrMap { rrMap := make(rrMap) for _, rr := range rrs { if rrMap[rr.Header().Rrtype] == nil { rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR) } if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil { rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0) } rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr) } return rrMap } func loadZonefile(filepath string) ([]dns.RR, error) { file, err := os.Open(filepath) if err != nil { return nil, err } parser := dns.NewZoneParser(file, "", "") var rrs = make([]dns.RR, 0) for rr, ok := parser.Next(); ok; rr, ok = parser.Next() { rrs = append(rrs, rr) } if err := parser.Err(); err != nil { log.Println(err) } return rrs, nil } func createACLList(config []configACL) (map[string]*net.IPNet, error) { acls := make(map[string]*net.IPNet) for _, aclRule := range config { _, mask, err := net.ParseCIDR(aclRule.CIDR) if err != nil { return nil, err } acls[aclRule.Name] = mask } return acls, nil } func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux { srv := dns.NewServeMux() for _, z := range zones { srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) ip := net.ParseIP(remoteIP) if len(z.acl) != 0 { passed := false for _, rule := range z.acl { if aclList[rule].Contains(ip) { passed = true } } if !passed { m := new(dns.Msg) m.SetReply(r) w.WriteMsg(m) return } } m := new(dns.Msg) m.SetReply(r) m.Authoritative = true // maybe only support one question per query like most servers do it ??? for _, q := range r.Question { rrs := z.rr[q.Qtype] m.Answer = append(m.Answer, rrs[q.Name]...) // Check for wildcard if len(m.Answer) == 0 { parts := dns.SplitDomainName(q.Name)[1:] searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) foundDomain := rrs[searchDomain] for _, rr := range foundDomain { rr.Header().Name = q.Name m.Answer = append(m.Answer, rr) } } // Handle extras switch q.Qtype { case dns.TypeMX: // Resolve MX domains for _, mxRR := range rrs[q.Name] { if t, ok := mxRR.(*dns.MX); ok { m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Mx]...) m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Mx]...) } } case dns.TypeA, dns.TypeAAAA: if len(m.Answer) == 0 { // no A or AAAA found. Look for CNAME m.Answer = append(m.Answer, z.rr[dns.TypeCNAME][q.Name]...) if len(m.Answer) != 0 { // Resolve CNAME for _, nameRR := range m.Answer { if t, ok := nameRR.(*dns.CNAME); ok { m.Answer = append(m.Answer, z.rr[q.Qtype][t.Target]...) } } } } case dns.TypeNS: // Resove NS records for _, nsRR := range rrs[q.Name] { if t, ok := nsRR.(*dns.NS); ok { m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Ns]...) m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Ns]...) } } } } w.WriteMsg(m) }) } return srv } func listenAndServer(server *dns.ServeMux) { go func() { if err := dns.ListenAndServe(":"+strconv.Itoa(8053), "udp", server); err != nil { log.Fatalf("Failed to set udp listener %s\n", err.Error()) } }() go func() { if err := dns.ListenAndServe(":"+strconv.Itoa(8053), "tcp", server); err != nil { log.Fatalf("Failed to set tcp listener %s\n", err.Error()) } }() sig := make(chan os.Signal) signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) s := <-sig log.Printf("Signal (%v) received, stopping\n", s) os.Exit(0) } func main() { config, err := loadConfig() if err != nil { log.Fatalf("Failed to load config: %s", err.Error()) } zones, err := loadZones(config.Zones) if err != nil { log.Fatalf("Failed to load zones: %s", err.Error()) } aclList, err := createACLList(config.ACL) if err != nil { log.Fatalf("Failed to parse ACL rules: %s", err.Error()) } server := createServer(zones, *config, aclList) listenAndServer(server) }