diff --git a/internal/authoritative.go b/internal/authoritative.go new file mode 100644 index 0000000..1e0a313 --- /dev/null +++ b/internal/authoritative.go @@ -0,0 +1,108 @@ +package cooldns + +import ( + "strings" + + "github.com/miekg/dns" +) + +// All record types to send when a ANY request is send +var anyRecordTypes = []uint16{ + dns.TypeSOA, + dns.TypeA, + dns.TypeAAAA, + dns.TypeNS, + dns.TypeCNAME, + dns.TypeMX, + dns.TypeTXT, + dns.TypeSRV, + dns.TypeCAA, +} + +// handleRequest find the right RR(s) in the view and send them back +func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { + m := new(dns.Msg) + m.SetReply(r) + m.Authoritative = true + + // Only support one question per query because all the other server also does that + if len(r.Question) != 1 { + rcodeRequest(w, r, dns.RcodeServerFailure) + } + + q := r.Question[0] + + rrs := zone.rr[q.Qtype] + + // Handle ANY + if q.Qtype == dns.TypeANY { + for _, rrType := range anyRecordTypes { + m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) + } + } else { + // Handle any other type + m.Answer = append(m.Answer, rrs[q.Name]...) + + // Check for wildcard + if len(m.Answer) == 0 { + parts := dns.SplitDomainName(q.Name)[1:] + searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) + foundDomain := rrs[searchDomain] + for _, rr := range foundDomain { + newRR := rr + newRR.Header().Name = q.Name + m.Answer = append(m.Answer, newRR) + } + } + } + + // Handle extras + switch q.Qtype { + // Dont handle extra stuff when answering ANY request + // case dns.TypeANY: + // fallthrough + case dns.TypeMX: + // Resolve MX domains + for _, mxRR := range m.Answer { + if t, ok := mxRR.(*dns.MX); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) + } + } + case dns.TypeA, dns.TypeAAAA: + if len(m.Answer) == 0 { + // no A or AAAA found. Look for CNAME + m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) + if len(m.Answer) != 0 { + // Resolve CNAME + for _, nameRR := range m.Answer { + if t, ok := nameRR.(*dns.CNAME); ok { + m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) + } + } + } + } + case dns.TypeNS: + // Resove NS records + for _, nsRR := range m.Answer { + if t, ok := nsRR.(*dns.NS); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) + } + } + case dns.TypeCNAME: + // Resolve CNAME + for _, cnameRR := range m.Answer { + if t, ok := cnameRR.(*dns.CNAME); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Target]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Target]...) + } + } + } + + if len(m.Answer) == 0 { + m.SetRcode(m, dns.RcodeNameError) + } + + w.WriteMsg(m) +} diff --git a/internal/cooldns.go b/internal/cooldns.go index f72a676..46ec9f3 100644 --- a/internal/cooldns.go +++ b/internal/cooldns.go @@ -5,7 +5,6 @@ import ( "net" "os" "path/filepath" - "strings" "github.com/miekg/dns" ) @@ -19,19 +18,6 @@ type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR -// All record types to send when a ANY request is send -var anyRecordTypes = []uint16{ - dns.TypeSOA, - dns.TypeA, - dns.TypeAAAA, - dns.TypeNS, - dns.TypeCNAME, - dns.TypeMX, - dns.TypeTXT, - dns.TypeSRV, - dns.TypeCAA, -} - // Start starts cooldns func Start(configPath string) { config, err := loadConfig(configPath) @@ -203,91 +189,3 @@ func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { m.SetRcode(r, rcode) w.WriteMsg(m) } - -// handleRequest find the right RR(s) in the view and send them back -func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { - m := new(dns.Msg) - m.SetReply(r) - m.Authoritative = true - - // Only support one question per query because all the other server also does that - if len(r.Question) != 1 { - rcodeRequest(w, r, dns.RcodeServerFailure) - } - - q := r.Question[0] - - rrs := zone.rr[q.Qtype] - - // Handle ANY - if q.Qtype == dns.TypeANY { - for _, rrType := range anyRecordTypes { - m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) - } - } else { - // Handle any other type - m.Answer = append(m.Answer, rrs[q.Name]...) - - // Check for wildcard - if len(m.Answer) == 0 { - parts := dns.SplitDomainName(q.Name)[1:] - searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) - foundDomain := rrs[searchDomain] - for _, rr := range foundDomain { - newRR := rr - newRR.Header().Name = q.Name - m.Answer = append(m.Answer, newRR) - } - } - } - - // Handle extras - switch q.Qtype { - // Dont handle extra stuff when answering ANY request - // case dns.TypeANY: - // fallthrough - case dns.TypeMX: - // Resolve MX domains - for _, mxRR := range m.Answer { - if t, ok := mxRR.(*dns.MX); ok { - m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) - m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) - } - } - case dns.TypeA, dns.TypeAAAA: - if len(m.Answer) == 0 { - // no A or AAAA found. Look for CNAME - m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) - if len(m.Answer) != 0 { - // Resolve CNAME - for _, nameRR := range m.Answer { - if t, ok := nameRR.(*dns.CNAME); ok { - m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) - } - } - } - } - case dns.TypeNS: - // Resove NS records - for _, nsRR := range m.Answer { - if t, ok := nsRR.(*dns.NS); ok { - m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) - m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) - } - } - case dns.TypeCNAME: - // Resolve CNAME - for _, cnameRR := range m.Answer { - if t, ok := cnameRR.(*dns.CNAME); ok { - m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Target]...) - m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Target]...) - } - } - } - - if len(m.Answer) == 0 { - m.SetRcode(m, dns.RcodeNameError) - } - - w.WriteMsg(m) -}