206 lines
4.2 KiB
Go
206 lines
4.2 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"io/ioutil"
|
||
|
"log"
|
||
|
"os"
|
||
|
"os/signal"
|
||
|
"strconv"
|
||
|
"syscall"
|
||
|
|
||
|
"github.com/miekg/dns"
|
||
|
"gopkg.in/yaml.v3"
|
||
|
)
|
||
|
|
||
|
type zone struct {
|
||
|
zone string
|
||
|
rr rrMap
|
||
|
acl []string
|
||
|
}
|
||
|
|
||
|
type rrMap map[uint16]map[string][]dns.RR
|
||
|
|
||
|
// Schema of the YAML configuration file decoded by loadConfig.
type (
	// config is the top-level configuration document.
	config struct {
		Zones   []configZone  `yaml:"zones"`
		ACL     []configACL   `yaml:"acl"`
		Forward configForward `yaml:"forward"`
	}

	// configForward is the `forward` section. Only an ACL list is defined.
	// NOTE(review): not read anywhere in this file.
	configForward struct {
		ACL []string `yaml:"acl"`
	}

	// configACL pairs a name with an IP range string (YAML key `range`).
	// The name is presumably referenced from the zone/forward ACL lists; confirm.
	configACL struct {
		Name    string `yaml:"name"`
		IPRange string `yaml:"range"`
	}

	// configZone describes one zone to serve: its name, the zone file that
	// holds its records, and the ACL names attached to it.
	configZone struct {
		Zone string   `yaml:"zone"`
		File string   `yaml:"file"`
		ACL  []string `yaml:"acl"`
	}
)
|
||
|
|
||
|
func loadConfig() (*config, error) {
|
||
|
file, err := ioutil.ReadFile("config.yml")
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
var loadedConfig config
|
||
|
|
||
|
err = yaml.Unmarshal(file, &loadedConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return &loadedConfig, nil
|
||
|
}
|
||
|
|
||
|
func loadZones(configZones []configZone) ([]zone, error) {
|
||
|
zones := make([]zone, 0)
|
||
|
for _, z := range configZones {
|
||
|
rrs, err := loadZonefile(z.File)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
zones = append(zones, zone{
|
||
|
zone: z.Zone,
|
||
|
rr: createRRMap(rrs),
|
||
|
})
|
||
|
log.Printf("Loaded zone %s\n", z.Zone)
|
||
|
}
|
||
|
|
||
|
return zones, nil
|
||
|
}
|
||
|
|
||
|
func createRRMap(rrs []dns.RR) rrMap {
|
||
|
rrMap := make(rrMap)
|
||
|
for _, rr := range rrs {
|
||
|
if rrMap[rr.Header().Rrtype] == nil {
|
||
|
rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR)
|
||
|
}
|
||
|
|
||
|
if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil {
|
||
|
rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0)
|
||
|
}
|
||
|
rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr)
|
||
|
}
|
||
|
|
||
|
return rrMap
|
||
|
}
|
||
|
|
||
|
func loadZonefile(filepath string) ([]dns.RR, error) {
|
||
|
file, err := os.Open(filepath)
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
parser := dns.NewZoneParser(file, "", "")
|
||
|
|
||
|
var rrs = make([]dns.RR, 0)
|
||
|
|
||
|
for rr, ok := parser.Next(); ok; rr, ok = parser.Next() {
|
||
|
rrs = append(rrs, rr)
|
||
|
}
|
||
|
|
||
|
if err := parser.Err(); err != nil {
|
||
|
log.Println(err)
|
||
|
}
|
||
|
|
||
|
return rrs, nil
|
||
|
}
|
||
|
|
||
|
func createServer(zones []zone, config config) *dns.ServeMux {
|
||
|
srv := dns.NewServeMux()
|
||
|
|
||
|
for _, z := range zones {
|
||
|
srv.HandleFunc(z.zone, func(w dns.ResponseWriter, r *dns.Msg) {
|
||
|
m := new(dns.Msg)
|
||
|
m.SetReply(r)
|
||
|
m.Authoritative = true
|
||
|
|
||
|
// maybe only support one question per query like most servers do it ???
|
||
|
for _, q := range r.Question {
|
||
|
rr := z.rr[q.Qtype]
|
||
|
|
||
|
m.Answer = append(m.Answer, rr[q.Name]...)
|
||
|
|
||
|
// Handle extras
|
||
|
switch q.Qtype {
|
||
|
case dns.TypeMX:
|
||
|
// Resolve MX domains
|
||
|
for _, mxRR := range rr[q.Name] {
|
||
|
if t, ok := mxRR.(*dns.MX); ok {
|
||
|
m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Mx]...)
|
||
|
m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Mx]...)
|
||
|
}
|
||
|
}
|
||
|
case dns.TypeA, dns.TypeAAAA:
|
||
|
if len(m.Answer) == 0 {
|
||
|
// no A or AAAA found. Look for CNAME
|
||
|
m.Answer = append(m.Answer, z.rr[dns.TypeCNAME][q.Name]...)
|
||
|
if len(m.Answer) != 0 {
|
||
|
// Resolve CNAME
|
||
|
for _, nameRR := range m.Answer {
|
||
|
if t, ok := nameRR.(*dns.CNAME); ok {
|
||
|
m.Answer = append(m.Answer, z.rr[q.Qtype][t.Target]...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
case dns.TypeNS:
|
||
|
// Resove NS records
|
||
|
for _, nsRR := range rr[q.Name] {
|
||
|
if t, ok := nsRR.(*dns.NS); ok {
|
||
|
m.Extra = append(m.Extra, z.rr[dns.TypeA][t.Ns]...)
|
||
|
m.Extra = append(m.Extra, z.rr[dns.TypeAAAA][t.Ns]...)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
w.WriteMsg(m)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
return srv
|
||
|
}
|
||
|
|
||
|
func listenAndServer(server *dns.ServeMux) {
|
||
|
go func() {
|
||
|
if err := dns.ListenAndServe(":"+strconv.Itoa(8053), "udp", server); err != nil {
|
||
|
log.Fatalf("Failed to set udp listener %s\n", err.Error())
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
go func() {
|
||
|
if err := dns.ListenAndServe(":"+strconv.Itoa(8053), "tcp", server); err != nil {
|
||
|
log.Fatalf("Failed to set tcp listener %s\n", err.Error())
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
sig := make(chan os.Signal)
|
||
|
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||
|
s := <-sig
|
||
|
log.Printf("Signal (%v) received, stopping\n", s)
|
||
|
os.Exit(0)
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
config, err := loadConfig()
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to load config: %s", err.Error())
|
||
|
}
|
||
|
|
||
|
zones, err := loadZones(config.Zones)
|
||
|
if err != nil {
|
||
|
log.Fatalf("Failed to load zones: %s", err.Error())
|
||
|
}
|
||
|
|
||
|
server := createServer(zones, *config)
|
||
|
|
||
|
listenAndServer(server)
|
||
|
}
|