cool-dns/coolDns.go

206 lines
4.2 KiB
Go
Raw Normal View History

2020-12-21 21:43:07 +00:00
package main
import (
"io/ioutil"
"log"
"os"
"os/signal"
"strconv"
"syscall"
"github.com/miekg/dns"
"gopkg.in/yaml.v3"
)
// Types describing the in-memory zones and the YAML configuration file.
type (
	// zone is one loaded zone: its name, its records indexed for lookup,
	// and an ACL list.
	// NOTE(review): acl presumably holds configACL names; nothing visible
	// here enforces it — confirm.
	zone struct {
		zone string
		rr   rrMap
		acl  []string
	}

	// rrMap indexes resource records first by RR type (the header's Rrtype)
	// and then by owner name, as built by createRRMap.
	rrMap map[uint16]map[string][]dns.RR

	// config mirrors the top-level structure of config.yml.
	config struct {
		Zones   []configZone  `yaml:"zones"`
		ACL     []configACL   `yaml:"acl"`
		Forward configForward `yaml:"forward"`
	}

	// configForward is the "forward" section of the config.
	// NOTE(review): ACL presumably lists the ACL names allowed to use
	// forwarding — confirm.
	configForward struct {
		ACL []string `yaml:"acl"`
	}

	// configACL is one entry of the "acl" section: a name and an IP range
	// string (the expected range format is not shown here).
	configACL struct {
		Name    string `yaml:"name"`
		IPRange string `yaml:"range"`
	}

	// configZone is one entry of the "zones" section: the zone name, the
	// zone file holding its records, and its ACL list.
	configZone struct {
		Zone string   `yaml:"zone"`
		File string   `yaml:"file"`
		ACL  []string `yaml:"acl"`
	}
)
// loadConfig reads config.yml from the current working directory and decodes
// it as YAML. It returns the read or decode error unchanged on failure.
func loadConfig() (*config, error) {
	raw, readErr := ioutil.ReadFile("config.yml")
	if readErr != nil {
		return nil, readErr
	}
	cfg := new(config)
	if decodeErr := yaml.Unmarshal(raw, cfg); decodeErr != nil {
		return nil, decodeErr
	}
	return cfg, nil
}
// loadZones loads the zone file of every configured zone and builds its
// in-memory record index. The zone's configured ACL list is carried over.
// It stops at, and returns, the first error reported by loadZonefile.
func loadZones(configZones []configZone) ([]zone, error) {
	zones := make([]zone, 0, len(configZones))
	for _, z := range configZones {
		rrs, err := loadZonefile(z.File)
		if err != nil {
			return nil, err
		}
		// Keep the configured ACL names; previously they were dropped here,
		// leaving zone.acl always empty.
		zones = append(zones, zone{
			zone: z.Zone,
			rr:   createRRMap(rrs),
			acl:  z.ACL,
		})
		log.Printf("Loaded zone %s\n", z.Zone)
	}
	return zones, nil
}
// createRRMap indexes rrs by record type and then by owner name. Records that
// share both type and name are kept in their original order.
func createRRMap(rrs []dns.RR) rrMap {
	index := make(rrMap)
	for _, record := range rrs {
		hdr := record.Header()
		byName := index[hdr.Rrtype]
		if byName == nil {
			byName = make(map[string][]dns.RR)
			index[hdr.Rrtype] = byName
		}
		// append on a missing key starts from a nil slice, so no explicit
		// initialisation is needed.
		byName[hdr.Name] = append(byName[hdr.Name], record)
	}
	return index
}
// loadZonefile parses the zone (master) file at filepath and returns the
// resource records the parser yields.
//
// Only a failure to open the file is returned as an error. A parse error is
// logged, and the records parsed before it are still returned.
func loadZonefile(filepath string) ([]dns.RR, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, err
	}
	// Release the descriptor once parsing is done; the file is only read, so
	// the Close error carries no useful information.
	defer file.Close()

	// Passing the path as the parser's file name makes parse errors report
	// which file they came from.
	parser := dns.NewZoneParser(file, "", filepath)
	rrs := make([]dns.RR, 0)
	for rr, ok := parser.Next(); ok; rr, ok = parser.Next() {
		rrs = append(rrs, rr)
	}
	if err := parser.Err(); err != nil {
		log.Println(err)
	}
	return rrs, nil
}
// createServer builds a dns.ServeMux with one handler per zone, registered on
// the zone's name.
//
// Each handler replies authoritatively using only that zone's in-memory
// records. Names with no matching records get an empty answer section. The
// config argument is accepted but not used yet.
func createServer(zones []zone, config config) *dns.ServeMux {
	srv := dns.NewServeMux()
	for _, z := range zones {
		// Copy the range variable. Before Go 1.22 it is shared by all
		// iterations, so every handler would otherwise answer from the last
		// zone.
		z := z
		srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) {
			m := new(dns.Msg)
			m.SetReply(r)
			m.Authoritative = true
			// TODO: consider answering only the first question, as most
			// servers do.
			for _, q := range r.Question {
				answerZoneQuestion(z, q, m)
			}
			if err := w.WriteMsg(m); err != nil {
				log.Printf("Failed to write response: %s\n", err.Error())
			}
		})
	}
	return srv
}

// answerZoneQuestion appends zone z's records matching q to m.Answer, then
// adds related records:
//   - MX and NS questions get A/AAAA records for their targets in m.Extra.
//   - A/AAAA questions with no direct match fall back to the name's CNAME,
//     plus the target's records of the requested type (one level).
//
// Only this question's own records drive the fallback, not whatever earlier
// questions already put in m.Answer.
func answerZoneQuestion(z zone, q dns.Question, m *dns.Msg) {
	records := z.rr[q.Qtype][q.Name]
	m.Answer = append(m.Answer, records...)

	switch q.Qtype {
	case dns.TypeMX:
		// Resolve MX targets.
		for _, rec := range records {
			if mx, ok := rec.(*dns.MX); ok {
				m.Extra = append(m.Extra, zoneAddressRecords(z, mx.Mx)...)
			}
		}
	case dns.TypeNS:
		// Resolve NS targets.
		for _, rec := range records {
			if ns, ok := rec.(*dns.NS); ok {
				m.Extra = append(m.Extra, zoneAddressRecords(z, ns.Ns)...)
			}
		}
	case dns.TypeA, dns.TypeAAAA:
		if len(records) != 0 {
			return
		}
		// No direct address record: answer with the CNAME (if any) and the
		// target's records of the requested type when they live in this zone.
		cnames := z.rr[dns.TypeCNAME][q.Name]
		m.Answer = append(m.Answer, cnames...)
		for _, rec := range cnames {
			if c, ok := rec.(*dns.CNAME); ok {
				m.Answer = append(m.Answer, z.rr[q.Qtype][c.Target]...)
			}
		}
	}
}

// zoneAddressRecords returns zone z's A records for name followed by its AAAA
// records. The result is a fresh slice, so callers may append to it without
// touching the zone's stored slices.
func zoneAddressRecords(z zone, name string) []dns.RR {
	a := z.rr[dns.TypeA][name]
	aaaa := z.rr[dns.TypeAAAA][name]
	addrs := make([]dns.RR, 0, len(a)+len(aaaa))
	addrs = append(addrs, a...)
	return append(addrs, aaaa...)
}
// listenAndServer starts UDP and TCP DNS listeners on port 8053 that dispatch
// to server. It then blocks until SIGINT or SIGTERM arrives and exits the
// process with status 0. If either listener fails, log.Fatalf ends the
// process.
func listenAndServer(server *dns.ServeMux) {
	addr := ":" + strconv.Itoa(8053)
	go func() {
		if err := dns.ListenAndServe(addr, "udp", server); err != nil {
			log.Fatalf("Failed to set udp listener %s\n", err.Error())
		}
	}()
	go func() {
		if err := dns.ListenAndServe(addr, "tcp", server); err != nil {
			log.Fatalf("Failed to set tcp listener %s\n", err.Error())
		}
	}()
	// signal.Notify never blocks when delivering. With an unbuffered channel,
	// a signal that arrives before the receive below is ready is dropped, so
	// give it room for one.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	s := <-sig
	log.Printf("Signal (%v) received, stopping\n", s)
	os.Exit(0)
}
// main loads config.yml and the zone files it lists, then serves those zones
// until the process is signalled. Any failure while loading is fatal.
func main() {
	cfg, err := loadConfig()
	if err != nil {
		log.Fatalf("Failed to load config: %s", err.Error())
	}

	zones, err := loadZones(cfg.Zones)
	if err != nil {
		log.Fatalf("Failed to load zones: %s", err.Error())
	}

	listenAndServer(createServer(zones, *cfg))
}