implemented diffeent zones bases on ip

This commit is contained in:
Niklas 2020-12-25 14:43:14 +01:00
parent 4e3025c581
commit 7d0e8e4b0d
2 changed files with 108 additions and 83 deletions

View File

@ -2,9 +2,18 @@ zones:
- zone: example.com. - zone: example.com.
file: zonefile.txt file: zonefile.txt
- zone: example.com.
file: zonefile2.txt
acl:
- lan
acl: acl:
- name: vpn - name: vpn
cidr: 10.0.0.0/24 cidr: 10.0.0.0/24
- name: lan
cidr: 192.168.0.0/16
- name: local
cidr: 127.0.0.1/32
forward: forward:
acl: acl:

View File

@ -14,12 +14,13 @@ import (
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
type zone struct { type zoneView struct {
zone string
rr rrMap rr rrMap
acl []string acl []string
} }
type zoneMap map[string][]zoneView
type rrMap map[uint16]map[string][]dns.RR type rrMap map[uint16]map[string][]dns.RR
type config struct { type config struct {
@ -72,18 +73,21 @@ func loadConfig() (*config, error) {
return &loadedConfig, nil return &loadedConfig, nil
} }
func loadZones(configZones []configZone) ([]zone, error) { func loadZones(configZones []configZone) (zoneMap, error) {
zones := make([]zone, 0) zones := make(zoneMap)
for _, z := range configZones { for _, z := range configZones {
rrs, err := loadZonefile(z.File) rrs, err := loadZonefile(z.File)
if err != nil { if err != nil {
return nil, err return nil, err
} }
zones = append(zones, zone{ if zones[z.Zone] == nil {
zone: z.Zone, zones[z.Zone] = make([]zoneView, 0)
}
zones[z.Zone] = append(zones[z.Zone], zoneView{
rr: createRRMap(rrs), rr: createRRMap(rrs),
acl: z.ACL, acl: z.ACL,
}) })
log.Printf("Loaded zone %s\n", z.Zone) log.Printf("Loaded zone %s\n", z.Zone)
} }
@ -144,12 +148,12 @@ func createACLList(config []configACL) (map[string]*net.IPNet, error) {
return acls, nil return acls, nil
} }
func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux { func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet) *dns.ServeMux {
srv := dns.NewServeMux() srv := dns.NewServeMux()
c := new(dns.Client) c := new(dns.Client)
for _, z := range zones { for zoneName, zones := range zones {
srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
// Parse IP // Parse IP
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
@ -159,80 +163,21 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d
return return
} }
// Check ACL rules // find out what view to handle the request
if !checkACL(z.acl, aclList, ip) { zoneIndex := -1
for i, zone := range zones {
if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) {
zoneIndex = i
}
}
if zoneIndex == -1 {
rcodeRequest(w, r, dns.RcodeRefused) rcodeRequest(w, r, dns.RcodeRefused)
return return
} }
m := new(dns.Msg) handleRequest(w, r, zones[zoneIndex])
m.SetReply(r)
m.Authoritative = true
// maybe only support one question per query like most servers do it ???
for _, q := range r.Question {
rrs := z.rr[q.Qtype]
// Handle ANY
if q.Qtype == dns.TypeANY {
for _, rrType := range anyRecordTypes {
m.Answer = append(m.Answer, z.rr[rrType][q.Name]...)
}
} else {
// Handle any other type
m.Answer = append(m.Answer, rrs[q.Name]...)
// Check for wildcard
if len(m.Answer) == 0 {
parts := dns.SplitDomainName(q.Name)[1:]
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
foundDomain := rrs[searchDomain]
for _, rr := range foundDomain {
newRR := rr
newRR.Header().Name = q.Name
m.Answer = append(m.Answer, newRR)
}
}
}
// Handle extras
switch q.Qtype {
// Dont handle extra stuff when answering ANY request
// case dns.TypeANY:
// fallthrough
case dns.TypeMX:
// Resolve MX domains
for _, mxRR := range m.Answer {
if t, ok := mxRR.(*dns.MX); ok {
m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Mx]...)
m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Mx]...)
}
}
case dns.TypeA, dns.TypeAAAA:
if len(m.Answer) == 0 {
// no A or AAAA found. Look for CNAME
m.Answer = append(m.Answer, z.rr[dns.TypeCNAME][q.Name]...)
if len(m.Answer) != 0 {
// Resolve CNAME
for _, nameRR := range m.Answer {
if t, ok := nameRR.(*dns.CNAME); ok {
m.Answer = append(m.Answer, z.rr[q.Qtype][t.Target]...)
}
}
}
}
case dns.TypeNS:
// Resove NS records
for _, nsRR := range m.Answer {
if t, ok := nsRR.(*dns.NS); ok {
m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Ns]...)
m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Ns]...)
}
}
}
}
w.WriteMsg(m)
}) })
} }
@ -307,6 +252,77 @@ func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
w.WriteMsg(m) w.WriteMsg(m)
} }
func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) {
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
// maybe only support one question per query like most servers do it ???
for _, q := range r.Question {
rrs := zone.rr[q.Qtype]
// Handle ANY
if q.Qtype == dns.TypeANY {
for _, rrType := range anyRecordTypes {
m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...)
}
} else {
// Handle any other type
m.Answer = append(m.Answer, rrs[q.Name]...)
// Check for wildcard
if len(m.Answer) == 0 {
parts := dns.SplitDomainName(q.Name)[1:]
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
foundDomain := rrs[searchDomain]
for _, rr := range foundDomain {
newRR := rr
newRR.Header().Name = q.Name
m.Answer = append(m.Answer, newRR)
}
}
}
// Handle extras
switch q.Qtype {
// Dont handle extra stuff when answering ANY request
// case dns.TypeANY:
// fallthrough
case dns.TypeMX:
// Resolve MX domains
for _, mxRR := range m.Answer {
if t, ok := mxRR.(*dns.MX); ok {
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...)
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...)
}
}
case dns.TypeA, dns.TypeAAAA:
if len(m.Answer) == 0 {
// no A or AAAA found. Look for CNAME
m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...)
if len(m.Answer) != 0 {
// Resolve CNAME
for _, nameRR := range m.Answer {
if t, ok := nameRR.(*dns.CNAME); ok {
m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...)
}
}
}
}
case dns.TypeNS:
// Resove NS records
for _, nsRR := range m.Answer {
if t, ok := nsRR.(*dns.NS); ok {
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...)
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...)
}
}
}
}
w.WriteMsg(m)
}
func main() { func main() {
config, err := loadConfig() config, err := loadConfig()
if err != nil { if err != nil {