implemented helper functions

This commit is contained in:
Niklas 2020-12-23 21:58:08 +01:00
parent b079e6988f
commit 4e3025c581

View File

@ -153,31 +153,17 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d
// Parse IP // Parse IP
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
if err != nil { ip := net.ParseIP(remoteIP)
if err != nil && ip != nil {
log.Printf("Faild to parse remote IP WTF? :%s", err.Error()) log.Printf("Faild to parse remote IP WTF? :%s", err.Error())
return return
} }
ip := net.ParseIP(remoteIP)
// Check ACL rules // Check ACL rules
if len(z.acl) != 0 { if !checkACL(z.acl, aclList, ip) {
passed := false rcodeRequest(w, r, dns.RcodeRefused)
for _, rule := range z.acl {
if aclList[rule].Contains(ip) {
passed = true
}
}
if !passed {
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(m)
return return
} }
}
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
@ -261,32 +247,16 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d
} }
// Check ACL rules // Check ACL rules
if len(config.Forward.ACL) != 0 { if !checkACL(config.Forward.ACL, aclList, ip) {
passed := false rcodeRequest(w, r, dns.RcodeRefused)
for _, rule := range config.Forward.ACL {
if aclList[rule].Contains(ip) {
passed = true
}
}
if !passed {
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, dns.RcodeRefused)
w.WriteMsg(m)
return return
} }
}
// Forward request // Forward request
in, _, err := c.Exchange(r, config.Forward.Server) in, _, err := c.Exchange(r, config.Forward.Server)
if err != nil { if err != nil {
m := new(dns.Msg) rcodeRequest(w, r, dns.RcodeServerFailure)
m.SetReply(r)
m.SetRcode(r, dns.RcodeServerFailure)
w.WriteMsg(m)
return return
} }
@ -316,6 +286,27 @@ func listenAndServer(server *dns.ServeMux) {
os.Exit(0) os.Exit(0)
} }
func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
if len(alcRules) != 0 {
passed := false
for _, rule := range alcRules {
if aclList[rule].Contains(ip) {
passed = true
}
}
return passed
}
return true
}
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, rcode)
w.WriteMsg(m)
}
func main() { func main() {
config, err := loadConfig() config, err := loadConfig()
if err != nil { if err != nil {