cool-dns/coolDns.go

432 lines
9.6 KiB
Go
Raw Normal View History

2020-12-21 21:43:07 +00:00
package main
import (
2020-12-26 20:31:50 +00:00
"flag"
2020-12-21 21:43:07 +00:00
"io/ioutil"
"log"
2020-12-22 18:07:39 +00:00
"net"
2020-12-21 21:43:07 +00:00
"os"
"os/signal"
2021-01-08 18:12:58 +00:00
"path/filepath"
2020-12-22 21:13:32 +00:00
"strings"
2020-12-21 21:43:07 +00:00
"syscall"
"github.com/miekg/dns"
"gopkg.in/yaml.v3"
)
2020-12-25 13:43:14 +00:00
type zoneView struct {
rr rrMap
acl []string
2020-12-21 21:43:07 +00:00
}
2020-12-25 13:43:14 +00:00
type zoneMap map[string][]zoneView
2020-12-21 21:43:07 +00:00
type rrMap map[uint16]map[string][]dns.RR
2021-01-06 14:53:58 +00:00
// config format of the config file
2020-12-21 21:43:07 +00:00
type config struct {
2020-12-27 21:13:13 +00:00
Zones []configZone `yaml:"zones"`
ACL []configACL `yaml:"acl"`
Forward configForward `yaml:"forward"`
Address string `yaml:"address"`
Blacklist []configBlacklist `yaml:"blacklist"`
2020-12-30 20:59:33 +00:00
TLS configTLS `yaml:"tls"`
2021-01-08 15:08:57 +00:00
Lego configLego `yaml:"lego"`
2020-12-21 21:43:07 +00:00
}
// Sub-sections of the YAML configuration file.
type (
	// configForward controls forwarding of queries for names that match no
	// local zone: the ACL rule names a client must match and the upstream
	// server the query is sent to.
	configForward struct {
		ACL    []string `yaml:"acl"`
		Server string   `yaml:"server"`
	}

	// configACL gives a CIDR network a name that zones and forwarding can
	// reference.
	configACL struct {
		Name string `yaml:"name"`
		CIDR string `yaml:"cidr"`
	}

	// configZone binds a zone name to the zone file it is loaded from and to
	// the ACL rule names that restrict who is served this view (empty means
	// unrestricted).
	configZone struct {
		Zone string   `yaml:"zone"`
		File string   `yaml:"file"`
		ACL  []string `yaml:"acl"`
	}

	// configBlacklist is a single blacklist source; how URL and Format are
	// interpreted is up to loadBlacklist.
	configBlacklist struct {
		URL    string `yaml:"url"`
		Format string `yaml:"format"`
	}

	// configTLS configures the DNS-over-TLS listener and the cert/key it
	// is started with.
	configTLS struct {
		Enable  bool   `yaml:"enable"`
		Address string `yaml:"address"`
		Cert    string `yaml:"cert"`
		Key     string `yaml:"key"`
	}
)
2021-01-06 14:53:58 +00:00
// All record types to send when a ANY request is send
2020-12-23 17:44:41 +00:00
var anyRecordTypes = []uint16{
dns.TypeSOA,
dns.TypeA,
dns.TypeAAAA,
dns.TypeNS,
dns.TypeCNAME,
dns.TypeMX,
dns.TypeTXT,
dns.TypeSRV,
dns.TypeCAA,
}
2020-12-26 20:31:50 +00:00
func loadConfig(configPath string) (*config, error) {
file, err := ioutil.ReadFile(configPath)
2020-12-21 21:43:07 +00:00
if err != nil {
return nil, err
}
var loadedConfig config
err = yaml.Unmarshal(file, &loadedConfig)
if err != nil {
return nil, err
}
return &loadedConfig, nil
}
2020-12-25 13:43:14 +00:00
func loadZones(configZones []configZone) (zoneMap, error) {
zones := make(zoneMap)
2020-12-21 21:43:07 +00:00
for _, z := range configZones {
2020-12-26 20:56:51 +00:00
rrs, err := loadZonefile(z.File, z.Zone)
2020-12-21 21:43:07 +00:00
if err != nil {
return nil, err
}
2020-12-25 13:43:14 +00:00
if zones[z.Zone] == nil {
zones[z.Zone] = make([]zoneView, 0)
}
zones[z.Zone] = append(zones[z.Zone], zoneView{
rr: createRRMap(rrs),
acl: z.ACL,
2020-12-21 21:43:07 +00:00
})
2020-12-25 13:43:14 +00:00
2020-12-21 21:43:07 +00:00
log.Printf("Loaded zone %s\n", z.Zone)
}
return zones, nil
}
2021-01-06 14:53:58 +00:00
// createRRMap order the rr into a structure that is more easy to use
2020-12-21 21:43:07 +00:00
func createRRMap(rrs []dns.RR) rrMap {
rrMap := make(rrMap)
for _, rr := range rrs {
if rrMap[rr.Header().Rrtype] == nil {
rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR)
}
if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil {
rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0)
}
rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr)
}
return rrMap
}
2020-12-26 20:56:51 +00:00
func loadZonefile(filepath, origin string) ([]dns.RR, error) {
2020-12-21 21:43:07 +00:00
file, err := os.Open(filepath)
if err != nil {
return nil, err
}
2020-12-26 20:56:51 +00:00
parser := dns.NewZoneParser(file, origin, filepath)
2020-12-21 21:43:07 +00:00
var rrs = make([]dns.RR, 0)
for rr, ok := parser.Next(); ok; rr, ok = parser.Next() {
rrs = append(rrs, rr)
}
if err := parser.Err(); err != nil {
log.Println(err)
}
return rrs, nil
}
2021-01-06 14:53:58 +00:00
// createACLList create a map with the CIDR and the name of the rule
2020-12-22 18:07:39 +00:00
func createACLList(config []configACL) (map[string]*net.IPNet, error) {
acls := make(map[string]*net.IPNet)
for _, aclRule := range config {
_, mask, err := net.ParseCIDR(aclRule.CIDR)
if err != nil {
return nil, err
}
acls[aclRule.Name] = mask
}
return acls, nil
}
2021-01-06 14:53:58 +00:00
// createServer creates a new serve mux. Adds all the logic to handle the request
2021-01-08 15:08:57 +00:00
func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux {
2020-12-21 21:43:07 +00:00
srv := dns.NewServeMux()
2020-12-23 20:44:33 +00:00
c := new(dns.Client)
2020-12-21 21:43:07 +00:00
2021-01-06 14:53:58 +00:00
// For all zones set from the config
2020-12-25 13:43:14 +00:00
for zoneName, zones := range zones {
srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
2020-12-22 18:07:39 +00:00
2020-12-23 12:10:58 +00:00
// Parse IP
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
2020-12-23 20:58:08 +00:00
ip := net.ParseIP(remoteIP)
if err != nil && ip != nil {
2020-12-26 20:56:51 +00:00
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
2020-12-23 12:10:58 +00:00
return
}
2021-01-08 15:08:57 +00:00
// Check if it is a ACME DNS-01 challange
if handleACMERequest(w, r, acmeList) {
return
}
2020-12-25 13:43:14 +00:00
// find out what view to handle the request
zoneIndex := -1
2020-12-22 18:07:39 +00:00
2020-12-25 13:43:14 +00:00
for i, zone := range zones {
if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) {
zoneIndex = i
2020-12-22 21:13:32 +00:00
}
2020-12-25 13:43:14 +00:00
}
2020-12-21 21:43:07 +00:00
2021-01-06 14:53:58 +00:00
// No view found that can handle the request
2020-12-25 13:43:14 +00:00
if zoneIndex == -1 {
rcodeRequest(w, r, dns.RcodeRefused)
return
2020-12-21 21:43:07 +00:00
}
2020-12-25 13:43:14 +00:00
handleRequest(w, r, zones[zoneIndex])
2020-12-21 21:43:07 +00:00
})
}
2021-01-06 14:53:58 +00:00
// Handle any other request for forwarding
2020-12-23 20:44:33 +00:00
srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
2021-01-06 14:53:58 +00:00
// Parse IP
2020-12-23 20:44:33 +00:00
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
ip := net.ParseIP(remoteIP)
if err != nil && ip != nil {
2020-12-26 20:56:51 +00:00
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
2020-12-23 20:44:33 +00:00
return
}
2021-01-08 15:08:57 +00:00
// Check if it is a ACME DNS-01 challange
if handleACMERequest(w, r, acmeList) {
return
}
2020-12-23 20:44:33 +00:00
// Check ACL rules
2020-12-23 20:58:08 +00:00
if !checkACL(config.Forward.ACL, aclList, ip) {
rcodeRequest(w, r, dns.RcodeRefused)
return
2020-12-23 20:44:33 +00:00
}
2021-01-06 14:53:58 +00:00
// Check if the domain is bocked
2020-12-27 21:13:13 +00:00
if _, ok := blacklist[r.Question[0].Name]; ok {
2020-12-27 23:36:54 +00:00
handleBlockedDomain(w, r)
2020-12-27 21:13:13 +00:00
} else {
// Forward request
in, _, err := c.Exchange(r, config.Forward.Server)
if err != nil {
rcodeRequest(w, r, dns.RcodeServerFailure)
return
}
2021-01-06 14:53:58 +00:00
2020-12-27 21:13:13 +00:00
w.WriteMsg(in)
2020-12-23 20:44:33 +00:00
}
})
2020-12-21 21:43:07 +00:00
return srv
}
2020-12-26 13:41:08 +00:00
func listenAndServer(server *dns.ServeMux, address string) {
2021-01-06 14:53:58 +00:00
// Start UDP listner
2020-12-21 21:43:07 +00:00
go func() {
2020-12-26 13:41:08 +00:00
if err := dns.ListenAndServe(address, "udp", server); err != nil {
2020-12-21 21:43:07 +00:00
log.Fatalf("Failed to set udp listener %s\n", err.Error())
}
}()
2021-01-06 14:53:58 +00:00
// Start TCP listner
2020-12-21 21:43:07 +00:00
go func() {
2020-12-26 13:41:08 +00:00
if err := dns.ListenAndServe(address, "tcp", server); err != nil {
2020-12-21 21:43:07 +00:00
log.Fatalf("Failed to set tcp listener %s\n", err.Error())
}
}()
}
2020-12-30 20:59:33 +00:00
func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) {
2021-01-06 14:53:58 +00:00
// Start TLS listner
2020-12-30 20:59:33 +00:00
go func() {
if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil {
log.Fatalf("Failed to set DoT listener %s", err.Error())
}
}()
}
2020-12-23 20:58:08 +00:00
// checkACL reports whether ip is admitted by aclRules, a list of ACL rule
// names that are looked up in aclList.
//
// An empty rule list admits every client. Otherwise ip must fall inside the
// network of at least one listed rule. A rule name with no entry in aclList
// (for example a typo in the config) matches nothing; looking it up used to
// dereference a nil *net.IPNet and panic the request goroutine.
func checkACL(aclRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
	if len(aclRules) == 0 {
		return true
	}
	for _, rule := range aclRules {
		network, ok := aclList[rule]
		if ok && network != nil && network.Contains(ip) {
			return true
		}
	}
	return false
}
2021-01-06 14:53:58 +00:00
// rcodeRequest respond to a request with a response code
2020-12-23 20:58:08 +00:00
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
m := new(dns.Msg)
m.SetReply(r)
m.SetRcode(r, rcode)
w.WriteMsg(m)
}
2021-01-06 14:53:58 +00:00
// handleRequest find the right RR(s) in the view and send them back
2020-12-25 13:43:14 +00:00
func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) {
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
2020-12-31 13:11:07 +00:00
// Only support one question per query because all the other server also does that
if len(r.Question) != 1 {
rcodeRequest(w, r, dns.RcodeServerFailure)
}
q := r.Question[0]
rrs := zone.rr[q.Qtype]
// Handle ANY
if q.Qtype == dns.TypeANY {
for _, rrType := range anyRecordTypes {
m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...)
}
} else {
// Handle any other type
m.Answer = append(m.Answer, rrs[q.Name]...)
// Check for wildcard
if len(m.Answer) == 0 {
parts := dns.SplitDomainName(q.Name)[1:]
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
foundDomain := rrs[searchDomain]
for _, rr := range foundDomain {
newRR := rr
newRR.Header().Name = q.Name
m.Answer = append(m.Answer, newRR)
2020-12-25 13:43:14 +00:00
}
}
2020-12-31 13:11:07 +00:00
}
2020-12-25 13:43:14 +00:00
2020-12-31 13:11:07 +00:00
// Handle extras
switch q.Qtype {
// Dont handle extra stuff when answering ANY request
// case dns.TypeANY:
// fallthrough
case dns.TypeMX:
// Resolve MX domains
for _, mxRR := range m.Answer {
if t, ok := mxRR.(*dns.MX); ok {
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...)
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...)
2020-12-25 13:43:14 +00:00
}
2020-12-31 13:11:07 +00:00
}
case dns.TypeA, dns.TypeAAAA:
if len(m.Answer) == 0 {
// no A or AAAA found. Look for CNAME
m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...)
if len(m.Answer) != 0 {
// Resolve CNAME
for _, nameRR := range m.Answer {
if t, ok := nameRR.(*dns.CNAME); ok {
m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...)
2020-12-25 13:43:14 +00:00
}
}
}
2020-12-31 13:11:07 +00:00
}
case dns.TypeNS:
// Resove NS records
for _, nsRR := range m.Answer {
if t, ok := nsRR.(*dns.NS); ok {
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...)
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...)
2020-12-25 13:43:14 +00:00
}
}
}
2021-01-06 22:06:47 +00:00
if len(m.Answer) == 0 {
m.SetRcode(m, dns.RcodeNameError)
}
2020-12-25 13:43:14 +00:00
w.WriteMsg(m)
}
2020-12-21 21:43:07 +00:00
func main() {
2020-12-26 20:31:50 +00:00
configPath := flag.String("c", "/etc/cool-dns/config.yaml", "path to the config file")
flag.Parse()
2021-01-08 20:26:28 +00:00
config, err := loadConfig(*configPath)
2021-01-08 18:12:58 +00:00
if err != nil {
2021-01-08 20:26:28 +00:00
log.Fatalf("Failed to load config: %s\n", err.Error())
2021-01-08 18:12:58 +00:00
}
2021-01-08 20:26:28 +00:00
err = os.Chdir(filepath.Dir(*configPath))
2020-12-21 21:43:07 +00:00
if err != nil {
2021-01-08 20:26:28 +00:00
log.Fatalf("Failed to goto config dir: %s", err.Error())
2020-12-21 21:43:07 +00:00
}
zones, err := loadZones(config.Zones)
if err != nil {
2020-12-26 20:34:11 +00:00
log.Fatalf("Failed to load zones: %s\n", err.Error())
2020-12-21 21:43:07 +00:00
}
2020-12-22 18:07:39 +00:00
aclList, err := createACLList(config.ACL)
if err != nil {
2020-12-26 20:34:11 +00:00
log.Fatalf("Failed to parse ACL rules: %s\n", err.Error())
2020-12-22 18:07:39 +00:00
}
2020-12-27 21:13:13 +00:00
blacklist := loadBlacklist(config.Blacklist)
2021-01-08 15:08:57 +00:00
var acmeMap *legoMap
if config.Lego.Enable {
acmeMap = startLEGOWebSever(config.Lego)
}
server := createServer(zones, *config, aclList, blacklist, acmeMap)
2020-12-21 21:43:07 +00:00
2020-12-26 13:41:08 +00:00
listenAndServer(server, config.Address)
2020-12-26 20:34:11 +00:00
2020-12-30 20:59:33 +00:00
if config.TLS.Enable {
listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key)
log.Printf("Start listening on tcp %s for tls", config.TLS.Address)
}
2020-12-26 20:34:11 +00:00
log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address)
sig := make(chan os.Signal)
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
s := <-sig
log.Printf("Signal (%v) received, stopping\n", s)
os.Exit(0)
2020-12-21 21:43:07 +00:00
}