added acl

This commit is contained in:
Niklas 2020-12-22 19:07:39 +01:00
parent 99568644d1
commit f77c41c13d
2 changed files with 48 additions and 7 deletions

View File

@ -1,6 +1,4 @@
zones:
- zone: example.com.
file: zonefile.txt
- zone: example.com.
file: zonefile.txt
acl:
@ -8,7 +6,7 @@ zones:
acl:
- name: vpn
range: 10.0.0.0/24
cidr: 10.0.0.0/24
forward:
alc:

View File

@ -3,6 +3,7 @@ package main
import (
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"strconv"
@ -31,8 +32,8 @@ type configForward struct {
}
type configACL struct {
Name string `yaml:"name"`
IPRange string `yaml:"range"`
Name string `yaml:"name"`
CIDR string `yaml:"cidr"`
}
type configZone struct {
@ -67,6 +68,7 @@ func loadZones(configZones []configZone) ([]zone, error) {
zones = append(zones, zone{
zone: z.Zone,
rr: createRRMap(rrs),
acl: z.ACL,
})
log.Printf("Loaded zone %s\n", z.Zone)
}
@ -112,11 +114,47 @@ func loadZonefile(filepath string) ([]dns.RR, error) {
return rrs, nil
}
func createServer(zones []zone, config config) *dns.ServeMux {
func createACLList(config []configACL) (map[string]*net.IPNet, error) {
acls := make(map[string]*net.IPNet)
for _, aclRule := range config {
_, mask, err := net.ParseCIDR(aclRule.CIDR)
if err != nil {
return nil, err
}
acls[aclRule.Name] = mask
}
return acls, nil
}
func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux {
srv := dns.NewServeMux()
for _, z := range zones {
srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) {
remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String())
ip := net.ParseIP(remoteIP)
if len(z.acl) != 0 {
passed := false
for _, rule := range z.acl {
if aclList[rule].Contains(ip) {
passed = true
}
}
if !passed {
m := new(dns.Msg)
m.SetReply(r)
w.WriteMsg(m)
return
}
}
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
@ -199,7 +237,12 @@ func main() {
log.Fatalf("Failed to load zones: %s", err.Error())
}
server := createServer(zones, *config)
aclList, err := createACLList(config.ACL)
if err != nil {
log.Fatalf("Failed to parse ACL rules: %s", err.Error())
}
server := createServer(zones, *config, aclList)
listenAndServer(server)
}