diff --git a/coolDns.go b/coolDns.go index c4024fa..447f883 100644 --- a/coolDns.go +++ b/coolDns.go @@ -153,30 +153,16 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d // Parse IP remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) - if err != nil { + ip := net.ParseIP(remoteIP) + if err != nil && ip != nil { log.Printf("Faild to parse remote IP WTF? :%s", err.Error()) return } - ip := net.ParseIP(remoteIP) - // Check ACL rules - if len(z.acl) != 0 { - passed := false - for _, rule := range z.acl { - - if aclList[rule].Contains(ip) { - passed = true - } - } - - if !passed { - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeRefused) - w.WriteMsg(m) - return - } + if !checkACL(z.acl, aclList, ip) { + rcodeRequest(w, r, dns.RcodeRefused) + return } m := new(dns.Msg) @@ -261,32 +247,16 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d } // Check ACL rules - if len(config.Forward.ACL) != 0 { - passed := false - for _, rule := range config.Forward.ACL { - - if aclList[rule].Contains(ip) { - passed = true - } - } - - if !passed { - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeRefused) - w.WriteMsg(m) - return - } + if !checkACL(config.Forward.ACL, aclList, ip) { + rcodeRequest(w, r, dns.RcodeRefused) + return } // Forward request in, _, err := c.Exchange(r, config.Forward.Server) if err != nil { - m := new(dns.Msg) - m.SetReply(r) - m.SetRcode(r, dns.RcodeServerFailure) - w.WriteMsg(m) + rcodeRequest(w, r, dns.RcodeServerFailure) return } @@ -316,6 +286,27 @@ func listenAndServer(server *dns.ServeMux) { os.Exit(0) } +func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool { + if len(alcRules) != 0 { + passed := false + for _, rule := range alcRules { + + if aclList[rule].Contains(ip) { + passed = true + } + } + return passed + } + return true +} + +func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) { + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, rcode) + w.WriteMsg(m) +} + func main() { config, err := loadConfig() if err != nil {