diff --git a/coolDns.go b/coolDns.go index a9e47d8..970d8db 100644 --- a/coolDns.go +++ b/coolDns.go @@ -278,65 +278,69 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) { m.SetReply(r) m.Authoritative = true - // maybe only support one question per query like most servers do it ??? - for _, q := range r.Question { - rrs := zone.rr[q.Qtype] + // Only support one question per query because all the other server also does that + if len(r.Question) != 1 { + rcodeRequest(w, r, dns.RcodeServerFailure) + } - // Handle ANY - if q.Qtype == dns.TypeANY { - for _, rrType := range anyRecordTypes { - m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) - } - } else { - // Handle any other type - m.Answer = append(m.Answer, rrs[q.Name]...) + q := r.Question[0] - // Check for wildcard - if len(m.Answer) == 0 { - parts := dns.SplitDomainName(q.Name)[1:] - searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) - foundDomain := rrs[searchDomain] - for _, rr := range foundDomain { - newRR := rr - newRR.Header().Name = q.Name - m.Answer = append(m.Answer, newRR) - } + rrs := zone.rr[q.Qtype] + + // Handle ANY + if q.Qtype == dns.TypeANY { + for _, rrType := range anyRecordTypes { + m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...) + } + } else { + // Handle any other type + m.Answer = append(m.Answer, rrs[q.Name]...) + + // Check for wildcard + if len(m.Answer) == 0 { + parts := dns.SplitDomainName(q.Name)[1:] + searchDomain := "*." + dns.Fqdn(strings.Join(parts, ".")) + foundDomain := rrs[searchDomain] + for _, rr := range foundDomain { + newRR := rr + newRR.Header().Name = q.Name + m.Answer = append(m.Answer, newRR) } } + } - // Handle extras - switch q.Qtype { - // Dont handle extra stuff when answering ANY request - // case dns.TypeANY: - // fallthrough - case dns.TypeMX: - // Resolve MX domains - for _, mxRR := range m.Answer { - if t, ok := mxRR.(*dns.MX); ok { - m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) - m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) - } + // Handle extras + switch q.Qtype { + // Dont handle extra stuff when answering ANY request + // case dns.TypeANY: + // fallthrough + case dns.TypeMX: + // Resolve MX domains + for _, mxRR := range m.Answer { + if t, ok := mxRR.(*dns.MX); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...) } - case dns.TypeA, dns.TypeAAAA: - if len(m.Answer) == 0 { - // no A or AAAA found. Look for CNAME - m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) - if len(m.Answer) != 0 { - // Resolve CNAME - for _, nameRR := range m.Answer { - if t, ok := nameRR.(*dns.CNAME); ok { - m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) - } + } + case dns.TypeA, dns.TypeAAAA: + if len(m.Answer) == 0 { + // no A or AAAA found. Look for CNAME + m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...) + if len(m.Answer) != 0 { + // Resolve CNAME + for _, nameRR := range m.Answer { + if t, ok := nameRR.(*dns.CNAME); ok { + m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...) } } } - case dns.TypeNS: - // Resove NS records - for _, nsRR := range m.Answer { - if t, ok := nsRR.(*dns.NS); ok { - m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) - m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) - } + } + case dns.TypeNS: + // Resove NS records + for _, nsRR := range m.Answer { + if t, ok := nsRR.(*dns.NS); ok { + m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...) + m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...) } } }