diff --git a/config.yml b/config.yml index f490914..aa5f263 100644 --- a/config.yml +++ b/config.yml @@ -2,9 +2,18 @@ zones: - zone: example.com. file: zonefile.txt +- zone: example.com. + file: zonefile2.txt + acl: + - lan + acl: - name: vpn cidr: 10.0.0.0/24 +- name: lan + cidr: 192.168.0.0/16 +- name: local + cidr: 127.0.0.1/32 forward: acl: diff --git a/coolDns.go b/coolDns.go index 447f883..92971e3 100644 --- a/coolDns.go +++ b/coolDns.go @@ -14,12 +14,13 @@ import ( "gopkg.in/yaml.v3" ) -type zone struct { - zone string - rr rrMap - acl []string +type zoneView struct { + rr rrMap + acl []string } +type zoneMap map[string][]zoneView + type rrMap map[uint16]map[string][]dns.RR type config struct { @@ -72,18 +73,21 @@ func loadConfig() (*config, error) { return &loadedConfig, nil } -func loadZones(configZones []configZone) ([]zone, error) { - zones := make([]zone, 0) +func loadZones(configZones []configZone) (zoneMap, error) { + zones := make(zoneMap) for _, z := range configZones { rrs, err := loadZonefile(z.File) if err != nil { return nil, err } - zones = append(zones, zone{ - zone: z.Zone, - rr: createRRMap(rrs), - acl: z.ACL, + if zones[z.Zone] == nil { + zones[z.Zone] = make([]zoneView, 0) + } + zones[z.Zone] = append(zones[z.Zone], zoneView{ + rr: createRRMap(rrs), + acl: z.ACL, }) + log.Printf("Loaded zone %s\n", z.Zone) } @@ -144,12 +148,12 @@ func createACLList(config []configACL) (map[string]*net.IPNet, error) { return acls, nil } -func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux { +func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet) *dns.ServeMux { srv := dns.NewServeMux() c := new(dns.Client) - for _, z := range zones { - srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { + for zoneName, zones := range zones { + srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) { // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) @@ -159,80 +163,21 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d return } - // Check ACL rules - if !checkACL(z.acl, aclList, ip) { + // find out what view to handle the request + zoneIndex := -1 + + for i, zone := range zones { + if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) { + zoneIndex = i + } + } + + if zoneIndex == -1 { rcodeRequest(w, r, dns.RcodeRefused) return } - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - - // maybe only support one question per query like most servers do it ??? - for _, q := range r.Question { - rrs := z.rr[q.Qtype] - - // Handle ANY - if q.Qtype == dns.TypeANY { - for _, rrType := range anyRecordTypes { - m.Answer = append(m.Answer, z.rr[rrType][q.Name]...) - } - } else { - // Handle any other type - m.Answer = append(m.Answer, rrs[q.Name]...) - - // Check for wildcard - if len(m.Answer) == 0 { - parts := dns.SplitDomainName(q.Name)[1:] - searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) - foundDomain := rrs[searchDomain] - for _, rr := range foundDomain { - newRR := rr - newRR.Header().Name = q.Name - m.Answer = append(m.Answer, newRR) - } - } - } - - // Handle extras - switch q.Qtype { - // Dont handle extra stuff when answering ANY request - // case dns.TypeANY: - // fallthrough - case dns.TypeMX: - // Resolve MX domains - for _, mxRR := range m.Answer { - if t, ok := mxRR.(*dns.MX); ok { - m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Mx]...) - m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Mx]...) - } - } - case dns.TypeA, dns.TypeAAAA: - if len(m.Answer) == 0 { - // no A or AAAA found. Look for CNAME - m.Answer = append(m.Answer, z.rr[dns.TypeCNAME][q.Name]...) - if len(m.Answer) != 0 { - // Resolve CNAME - for _, nameRR := range m.Answer { - if t, ok := nameRR.(*dns.CNAME); ok { - m.Answer = append(m.Answer, z.rr[q.Qtype][t.Target]...) - } - } - } - } - case dns.TypeNS: - // Resove NS records - for _, nsRR := range m.Answer { - if t, ok := nsRR.(*dns.NS); ok { - m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Ns]...) - m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Ns]...) - } - } - } - } - - w.WriteMsg(m) + handleRequest(w, r, zones[zoneIndex]) }) } @@ -307,6 +252,77 @@ func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { w.WriteMsg(m) } +func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + + // maybe only support one question per query like most servers do it ??? + for _, q := range r.Question { + rrs := zone.rr[q.Qtype] + + // Handle ANY + if q.Qtype == dns.TypeANY { + for _, rrType := range anyRecordTypes { + m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) + } + } else { + // Handle any other type + m.Answer = append(m.Answer, rrs[q.Name]...) + + // Check for wildcard + if len(m.Answer) == 0 { + parts := dns.SplitDomainName(q.Name)[1:] + searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) + foundDomain := rrs[searchDomain] + for _, rr := range foundDomain { + newRR := rr + newRR.Header().Name = q.Name + m.Answer = append(m.Answer, newRR) + } + } + } + + // Handle extras + switch q.Qtype { + // Dont handle extra stuff when answering ANY request + // case dns.TypeANY: + // fallthrough + case dns.TypeMX: + // Resolve MX domains + for _, mxRR := range m.Answer { + if t, ok := mxRR.(*dns.MX); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) + } + } + case dns.TypeA, dns.TypeAAAA: + if len(m.Answer) == 0 { + // no A or AAAA found. Look for CNAME + m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) + if len(m.Answer) != 0 { + // Resolve CNAME + for _, nameRR := range m.Answer { + if t, ok := nameRR.(*dns.CNAME); ok { + m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) + } + } + } + } + case dns.TypeNS: + // Resove NS records + for _, nsRR := range m.Answer { + if t, ok := nsRR.(*dns.NS); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) + } + } + } + } + + w.WriteMsg(m) +} + func main() { config, err := loadConfig() if err != nil {