From b079e6988f8b58b0eb12c322dd5b23aeecbeb011 Mon Sep 17 00:00:00 2001 From: Niklas Date: Wed, 23 Dec 2020 21:44:33 +0100 Subject: [PATCH] implemented forwarding --- config.yml | 7 ++----- coolDns.go | 47 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/config.yml b/config.yml index bfdc939..f490914 100644 --- a/config.yml +++ b/config.yml @@ -1,15 +1,12 @@ zones: - zone: example.com. file: zonefile.txt - acl: - - vpn acl: - name: vpn cidr: 10.0.0.0/24 forward: - alc: + acl: - vpn - - + server: "8.8.8.8:53" diff --git a/coolDns.go b/coolDns.go index 86178b4..c4024fa 100644 --- a/coolDns.go +++ b/coolDns.go @@ -29,7 +29,8 @@ type config struct { } type configForward struct { - ACL []string `yaml:"acl"` + ACL []string `yaml:"acl"` + Server string `yaml:"server"` } type configACL struct { @@ -145,6 +146,7 @@ func createACLList(config []configACL) (map[string]*net.IPNet, error) { func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *dns.ServeMux { srv := dns.NewServeMux() + c := new(dns.Client) for _, z := range zones { srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) { @@ -248,6 +250,49 @@ func createServer(zones []zone, config config, aclList map[string]*net.IPNet) *d }) } + // Handle any other request + srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { + remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String()) + ip := net.ParseIP(remoteIP) + + if err != nil && ip != nil { + log.Printf("Faild to parse remote IP WTF? :%s", err.Error()) + return + } + + // Check ACL rules + if len(config.Forward.ACL) != 0 { + passed := false + for _, rule := range config.Forward.ACL { + + if aclList[rule].Contains(ip) { + passed = true + } + } + + if !passed { + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, dns.RcodeRefused) + w.WriteMsg(m) + return + } + } + + // Forward request + in, _, err := c.Exchange(r, config.Forward.Server) + + if err != nil { + m := new(dns.Msg) + m.SetReply(r) + m.SetRcode(r, dns.RcodeServerFailure) + w.WriteMsg(m) + return + } + + w.WriteMsg(in) + }) + return srv }