added acl

This commit is contained in:
Niklas 2020-12-22 19:07:39 +01:00
parent 99568644d1
commit f77c41c13d
2 changed files with 48 additions and 7 deletions

View File

@ -1,6 +1,4 @@
zones: zones:
- zone: example.com.
file: zonefile.txt
- zone: example.com. - zone: example.com.
file: zonefile.txt file: zonefile.txt
acl: acl:
@ -8,7 +6,7 @@ zones:
acl: acl:
- name: vpn - name: vpn
range: 10.0.0.0/24 cidr: 10.0.0.0/24
forward: forward:
alc: alc:

View File

@ -3,6 +3,7 @@ package main
import ( import (
"io/ioutil" "io/ioutil"
"log" "log"
"net"
"os" "os"
"os/signal" "os/signal"
"strconv" "strconv"
@ -32,7 +33,7 @@ type configForward struct {
type configACL struct { type configACL struct {
Name string `yaml:"name"` Name string `yaml:"name"`
IPRange string `yaml:"range"` CIDR string `yaml:"cidr"`
} }
type configZone struct { type configZone struct {
@ -67,6 +68,7 @@ func loadZones(configZones []configZone) ([]zone, error) {
zones = append(zones, zone{ zones = append(zones, zone{
zone: z.Zone, zone: z.Zone,
rr: createRRMap(rrs), rr: createRRMap(rrs),
acl: z.ACL,
}) })
log.Printf("Loaded zone %s\n", z.Zone) log.Printf("Loaded zone %s\n", z.Zone)
} }
@ -112,11 +114,47 @@ func loadZonefile(filepath string) ([]dns.RR, error) {
return rrs, nil return rrs, nil
} }
func createServer(zones []zone, config config) *dns.ServeMux { func createACLList(config []configACL) (map[string]*net.IPNet, error) {
acls := make(map[string]*net.IPNet)
for _, aclRule := range config {
_, mask, err := net.ParseCIDR(aclRule.CIDR)
if err != nil {
return nil, err
}
acls[aclRule.Name] = mask
}
return acls, nil
}
func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux {
srv := dns.NewServeMux() srv := dns.NewServeMux()
for _, z := range zones { for _, z := range zones {
srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) {
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
ip := net.ParseIP(remoteIP)
if len(z.acl) != 0 {
passed := false
for _, rule := range z.acl {
if aclList[rule].Contains(ip) {
passed = true
}
}
if !passed {
m := new(dns.Msg)
m.SetReply(r)
w.WriteMsg(m)
return
}
}
m := new(dns.Msg) m := new(dns.Msg)
m.SetReply(r) m.SetReply(r)
m.Authoritative = true m.Authoritative = true
@ -199,7 +237,12 @@ func main() {
log.Fatalf("Failed to load zones: %s", err.Error()) log.Fatalf("Failed to load zones: %s", err.Error())
} }
server := createServer(zones, *config) aclList, err := createACLList(config.ACL)
if err != nil {
log.Fatalf("Failed to parse ACL rules: %s", err.Error())
}
server := createServer(zones, *config, aclList)
listenAndServer(server) listenAndServer(server)
} }