diff --git a/ResponseWrapper.go b/ResponseWrapper.go index 75f1fa4..c2edc46 100644 --- a/ResponseWrapper.go +++ b/ResponseWrapper.go @@ -3,12 +3,9 @@ package override import ( "net" - clog "github.com/coredns/coredns/plugin/pkg/log" "github.com/miekg/dns" ) -var log = clog.NewWithPlugin("overide") - type Rule struct { Search net.IP Override net.IP @@ -22,23 +19,31 @@ type ResponseWrapper struct { func (r *ResponseWrapper) WriteMsg(res *dns.Msg) error { for _, r := range r.Rules { for _, rr := range res.Answer { - if rr.Header().Rrtype == dns.TypeA { - a := rr.(*dns.A) - if a.A.Equal(r.Search) { - a.A = r.Override - } - } else if rr.Header().Rrtype == dns.TypeAAAA { - a := rr.(*dns.AAAA) - if a.AAAA.Equal(r.Search) { - a.AAAA = r.Override - } - } + overideRR(r, rr) + } + + for _, rr := range res.Extra { + overideRR(r, rr) } } return r.ResponseWriter.WriteMsg(res) } +func overideRR(r Rule, rr dns.RR) { + if rr.Header().Rrtype == dns.TypeA { + a := rr.(*dns.A) + if a.A.Equal(r.Search) { + a.A = r.Override + } + } else if rr.Header().Rrtype == dns.TypeAAAA { + a := rr.(*dns.AAAA) + if a.AAAA.Equal(r.Search) { + a.AAAA = r.Override + } + } +} + func (r *ResponseWrapper) Write(buf []byte) (int, error) { log.Warning("ResponseWrapper called with Write: not ensuring overide") return r.ResponseWriter.Write(buf)