diff --git a/config.yml b/config.yml index b2c5106..bfdc939 100644 --- a/config.yml +++ b/config.yml @@ -1,6 +1,4 @@ zones: -- zone: example.com. - file: zonefile.txt - zone: example.com. file: zonefile.txt acl: @@ -8,7 +6,7 @@ zones: acl: - name: vpn - range: 10.0.0.0/24 + cidr: 10.0.0.0/24 forward: alc: diff --git a/coolDns.go b/coolDns.go index a9461df..1d4b429 100644 --- a/coolDns.go +++ b/coolDns.go @@ -3,6 +3,7 @@ package main import ( "io/ioutil" "log" + "net" "os" "os/signal" "strconv" @@ -31,8 +32,8 @@ type configForward struct { } type configACL struct { - Name string `yaml:"name"` - IPRange string `yaml:"range"` + Name string `yaml:"name"` + CIDR string `yaml:"cidr"` } type configZone struct { @@ -67,6 +68,7 @@ func loadZones(configZones []configZone) ([]zone, error) { zones = append(zones, zone{ zone: z.Zone, rr: createRRMap(rrs), + acl: z.ACL, }) log.Printf("Loaded zone %s\n", z.Zone) } @@ -112,11 +114,47 @@ func loadZonefile(filepath string) ([]dns.RR, error) { return rrs, nil } -func createServer(zones []zone, config config) *dns.ServeMux { +func createACLList(config []configACL) (map[string]*net.IPNet, error) { + acls := make(map[string]*net.IPNet) + + for _, aclRule := range config { + _, mask, err := net.ParseCIDR(aclRule.CIDR) + + if err != nil { + return nil, err + } + + acls[aclRule.Name] = mask + } + + return acls, nil +} + +func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux { srv := dns.NewServeMux() for _, z := range zones { srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { + + remoteIP, _, _ := net.SplitHostPort(w.RemoteAddr().String()) + ip := net.ParseIP(remoteIP) + if len(z.acl) != 0 { + passed := false + for _, rule := range z.acl { + + if aclList[rule].Contains(ip) { + passed = true + } + } + + if !passed { + m := new(dns.Msg) + m.SetReply(r) + w.WriteMsg(m) + return + } + } + m := new(dns.Msg) m.SetReply(r) m.Authoritative = true @@ -199,7 +237,12 @@ func main() { log.Fatalf("Failed to load zones: %s", err.Error()) } - server := createServer(zones, *config) + aclList, err := createACLList(config.ACL) + if err != nil { + log.Fatalf("Failed to parse ACL rules: %s", err.Error()) + } + + server := createServer(zones, *config, aclList) listenAndServer(server) }