Compare commits
17 Commits
9adc685a73
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
| afdea7e64f | |||
| 4394d6078d | |||
| ebce605863 | |||
| 7255fc02c8 | |||
| 00a82aca87 | |||
| 8d366a9833 | |||
| 07271dea71 | |||
| 1709b2099a | |||
| 8f499d8f85 | |||
| 0fe5d73853 | |||
| f0f4fa5376 | |||
| f497c96ad1 | |||
| 39ca792d74 | |||
| 30a0b7c5df | |||
| f5c2376b36 | |||
| d56b459d9a | |||
| 0336002980 |
10
.drone.yml
10
.drone.yml
@@ -5,7 +5,7 @@ steps:
|
||||
- name: build
|
||||
image: golang
|
||||
commands:
|
||||
- go build
|
||||
- go build cmd/cooldns.go
|
||||
|
||||
- name: gitea_release
|
||||
image: plugins/gitea-release
|
||||
@@ -14,7 +14,7 @@ steps:
|
||||
from_secret: GITEA_API_KEY
|
||||
base_url: https://git.kapelle.org
|
||||
files:
|
||||
- cool-dns
|
||||
- cooldns
|
||||
checksum:
|
||||
- md5
|
||||
- sha1
|
||||
@@ -25,6 +25,12 @@ steps:
|
||||
event:
|
||||
- tag
|
||||
|
||||
- name: docker
|
||||
image: plugins/docker
|
||||
settings:
|
||||
repo: docker.kapelle.org/cooldns
|
||||
registry: docker.kapelle.org
|
||||
|
||||
trigger:
|
||||
event:
|
||||
- tag
|
||||
|
||||
12
Dockerfile
12
Dockerfile
@@ -2,15 +2,15 @@ FROM golang:alpine AS build
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY go.mod .
|
||||
COPY go.sum .
|
||||
COPY *.go ./
|
||||
COPY [ "go.mod", "go.sum", "./"]
|
||||
COPY internal ./internal
|
||||
COPY cmd ./cmd
|
||||
|
||||
RUN go build -o cool-dns .
|
||||
RUN go build -o cooldns cmd/cooldns.go
|
||||
|
||||
FROM alpine:latest
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=build /build/cool-dns .
|
||||
COPY --from=build /build/cooldns .
|
||||
|
||||
ENTRYPOINT ["/app/cool-dns"]
|
||||
ENTRYPOINT ["/app/cooldns"]
|
||||
|
||||
@@ -42,4 +42,10 @@ blacklist: # What domains to block when forwarding
|
||||
format: host # Format of the blacklist: Hostfile
|
||||
- url: https://blocklistproject.github.io/Lists/alt-version/ads-nl.txt
|
||||
format: line # Format: One domain per line
|
||||
|
||||
lego: # Support for Lego http provider. See https://go-acme.github.io/lego/dns/httpreq/
|
||||
enable: true
|
||||
address: :8080
|
||||
username: lego
|
||||
secret: "133742069ab"
|
||||
```
|
||||
24
cmd/cooldns.go
Normal file
24
cmd/cooldns.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
cooldns "git.kapelle.org/niklas/cool-dns/internal"
|
||||
)
|
||||
|
||||
func main() {
|
||||
configPath := flag.String("c", "/etc/cool-dns/config.yaml", "path to the config file")
|
||||
flag.Parse()
|
||||
|
||||
cooldns.Start(*configPath)
|
||||
|
||||
sig := make(chan os.Signal)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
s := <-sig
|
||||
log.Printf("Signal (%v) received, stopping\n", s)
|
||||
os.Exit(0)
|
||||
}
|
||||
435
coolDns.go
435
coolDns.go
@@ -1,435 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type zoneView struct {
|
||||
rr rrMap
|
||||
acl []string
|
||||
}
|
||||
|
||||
type zoneMap map[string][]zoneView
|
||||
|
||||
type rrMap map[uint16]map[string][]dns.RR
|
||||
|
||||
// config format of the config file
|
||||
type config struct {
|
||||
Zones []configZone `yaml:"zones"`
|
||||
ACL []configACL `yaml:"acl"`
|
||||
Forward configForward `yaml:"forward"`
|
||||
Address string `yaml:"address"`
|
||||
Blacklist []configBlacklist `yaml:"blacklist"`
|
||||
TLS configTLS `yaml:"tls"`
|
||||
Lego configLego `yaml:"lego"`
|
||||
}
|
||||
|
||||
type configForward struct {
|
||||
ACL []string `yaml:"acl"`
|
||||
Server string `yaml:"server"`
|
||||
}
|
||||
|
||||
type configACL struct {
|
||||
Name string `yaml:"name"`
|
||||
CIDR string `yaml:"cidr"`
|
||||
}
|
||||
|
||||
type configZone struct {
|
||||
Zone string `yaml:"zone"`
|
||||
File string `yaml:"file"`
|
||||
ACL []string `yaml:"acl"`
|
||||
}
|
||||
|
||||
type configBlacklist struct {
|
||||
URL string `yaml:"url"`
|
||||
Format string `yaml:"format"`
|
||||
}
|
||||
|
||||
type configTLS struct {
|
||||
Enable bool `yaml:"enable"`
|
||||
Address string `yaml:"address"`
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
}
|
||||
|
||||
// All record types to send when a ANY request is send
|
||||
var anyRecordTypes = []uint16{
|
||||
dns.TypeSOA,
|
||||
dns.TypeA,
|
||||
dns.TypeAAAA,
|
||||
dns.TypeNS,
|
||||
dns.TypeCNAME,
|
||||
dns.TypeMX,
|
||||
dns.TypeTXT,
|
||||
dns.TypeSRV,
|
||||
dns.TypeCAA,
|
||||
}
|
||||
|
||||
func start(configPath string) {
|
||||
config, err := loadConfig(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %s\n", err.Error())
|
||||
}
|
||||
|
||||
err = os.Chdir(filepath.Dir(configPath))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to goto config dir: %s", err.Error())
|
||||
}
|
||||
|
||||
zones, err := loadZones(config.Zones)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load zones: %s\n", err.Error())
|
||||
}
|
||||
|
||||
aclList, err := createACLList(config.ACL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse ACL rules: %s\n", err.Error())
|
||||
}
|
||||
|
||||
blacklist := loadBlacklist(config.Blacklist)
|
||||
|
||||
var acmeMap *legoMap
|
||||
if config.Lego.Enable {
|
||||
acmeMap = startLEGOWebSever(config.Lego)
|
||||
}
|
||||
|
||||
server := createServer(zones, *config, aclList, blacklist, acmeMap)
|
||||
|
||||
listenAndServer(server, config.Address)
|
||||
|
||||
if config.TLS.Enable {
|
||||
listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key)
|
||||
|
||||
log.Printf("Start listening on tcp %s for tls", config.TLS.Address)
|
||||
}
|
||||
|
||||
log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address)
|
||||
}
|
||||
|
||||
func loadConfig(configPath string) (*config, error) {
|
||||
file, err := ioutil.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var loadedConfig config
|
||||
|
||||
err = yaml.Unmarshal(file, &loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &loadedConfig, nil
|
||||
}
|
||||
|
||||
func loadZones(configZones []configZone) (zoneMap, error) {
|
||||
zones := make(zoneMap)
|
||||
for _, z := range configZones {
|
||||
rrs, err := loadZonefile(z.File, z.Zone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if zones[z.Zone] == nil {
|
||||
zones[z.Zone] = make([]zoneView, 0)
|
||||
}
|
||||
zones[z.Zone] = append(zones[z.Zone], zoneView{
|
||||
rr: createRRMap(rrs),
|
||||
acl: z.ACL,
|
||||
})
|
||||
|
||||
log.Printf("Loaded zone %s\n", z.Zone)
|
||||
}
|
||||
|
||||
return zones, nil
|
||||
}
|
||||
|
||||
// createRRMap order the rr into a structure that is more easy to use
|
||||
func createRRMap(rrs []dns.RR) rrMap {
|
||||
rrMap := make(rrMap)
|
||||
for _, rr := range rrs {
|
||||
if rrMap[rr.Header().Rrtype] == nil {
|
||||
rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR)
|
||||
}
|
||||
|
||||
if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil {
|
||||
rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0)
|
||||
}
|
||||
rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr)
|
||||
}
|
||||
|
||||
return rrMap
|
||||
}
|
||||
|
||||
func loadZonefile(filepath, origin string) ([]dns.RR, error) {
|
||||
file, err := os.Open(filepath)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := dns.NewZoneParser(file, origin, filepath)
|
||||
|
||||
var rrs = make([]dns.RR, 0)
|
||||
|
||||
for rr, ok := parser.Next(); ok; rr, ok = parser.Next() {
|
||||
rrs = append(rrs, rr)
|
||||
}
|
||||
|
||||
if err := parser.Err(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
return rrs, nil
|
||||
}
|
||||
|
||||
// createACLList create a map with the CIDR and the name of the rule
|
||||
func createACLList(config []configACL) (map[string]*net.IPNet, error) {
|
||||
acls := make(map[string]*net.IPNet)
|
||||
|
||||
for _, aclRule := range config {
|
||||
_, mask, err := net.ParseCIDR(aclRule.CIDR)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
acls[aclRule.Name] = mask
|
||||
}
|
||||
|
||||
return acls, nil
|
||||
}
|
||||
|
||||
// createServer creates a new serve mux. Adds all the logic to handle the request
|
||||
func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux {
|
||||
srv := dns.NewServeMux()
|
||||
c := new(dns.Client)
|
||||
|
||||
// For all zones set from the config
|
||||
for zoneName, zones := range zones {
|
||||
srv.HandleFunc(zoneName, func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Parse IP
|
||||
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ip := net.ParseIP(remoteIP)
|
||||
if err != nil && ip != nil {
|
||||
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is a ACME DNS-01 challange
|
||||
if handleACMERequest(w, r, acmeList) {
|
||||
return
|
||||
}
|
||||
|
||||
// find out what view to handle the request
|
||||
zoneIndex := -1
|
||||
|
||||
for i, zone := range zones {
|
||||
if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) {
|
||||
zoneIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// No view found that can handle the request
|
||||
if zoneIndex == -1 {
|
||||
rcodeRequest(w, r, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
handleRequest(w, r, zones[zoneIndex])
|
||||
})
|
||||
}
|
||||
|
||||
// Handle any other request for forwarding
|
||||
srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Parse IP
|
||||
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ip := net.ParseIP(remoteIP)
|
||||
|
||||
if err != nil && ip != nil {
|
||||
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is a ACME DNS-01 challange
|
||||
if handleACMERequest(w, r, acmeList) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check ACL rules
|
||||
if !checkACL(config.Forward.ACL, aclList, ip) {
|
||||
rcodeRequest(w, r, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the domain is bocked
|
||||
if _, ok := blacklist[r.Question[0].Name]; ok {
|
||||
handleBlockedDomain(w, r)
|
||||
} else {
|
||||
// Forward request
|
||||
in, _, err := c.Exchange(r, config.Forward.Server)
|
||||
|
||||
if err != nil {
|
||||
rcodeRequest(w, r, dns.RcodeServerFailure)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteMsg(in)
|
||||
}
|
||||
})
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func listenAndServer(server *dns.ServeMux, address string) {
|
||||
// Start UDP listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServe(address, "udp", server); err != nil {
|
||||
log.Fatalf("Failed to set udp listener %s\n", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
// Start TCP listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServe(address, "tcp", server); err != nil {
|
||||
log.Fatalf("Failed to set tcp listener %s\n", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) {
|
||||
// Start TLS listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil {
|
||||
log.Fatalf("Failed to set DoT listener %s", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
|
||||
if len(alcRules) != 0 {
|
||||
passed := false
|
||||
for _, rule := range alcRules {
|
||||
|
||||
if aclList[rule].Contains(ip) {
|
||||
passed = true
|
||||
}
|
||||
}
|
||||
return passed
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// rcodeRequest respond to a request with a response code
|
||||
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.SetRcode(r, rcode)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
// handleRequest find the right RR(s) in the view and send them back
|
||||
func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
|
||||
// Only support one question per query because all the other server also does that
|
||||
if len(r.Question) != 1 {
|
||||
rcodeRequest(w, r, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
|
||||
rrs := zone.rr[q.Qtype]
|
||||
|
||||
// Handle ANY
|
||||
if q.Qtype == dns.TypeANY {
|
||||
for _, rrType := range anyRecordTypes {
|
||||
m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...)
|
||||
}
|
||||
} else {
|
||||
// Handle any other type
|
||||
m.Answer = append(m.Answer, rrs[q.Name]...)
|
||||
|
||||
// Check for wildcard
|
||||
if len(m.Answer) == 0 {
|
||||
parts := dns.SplitDomainName(q.Name)[1:]
|
||||
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
|
||||
foundDomain := rrs[searchDomain]
|
||||
for _, rr := range foundDomain {
|
||||
newRR := rr
|
||||
newRR.Header().Name = q.Name
|
||||
m.Answer = append(m.Answer, newRR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extras
|
||||
switch q.Qtype {
|
||||
// Dont handle extra stuff when answering ANY request
|
||||
// case dns.TypeANY:
|
||||
// fallthrough
|
||||
case dns.TypeMX:
|
||||
// Resolve MX domains
|
||||
for _, mxRR := range m.Answer {
|
||||
if t, ok := mxRR.(*dns.MX); ok {
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...)
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...)
|
||||
}
|
||||
}
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
if len(m.Answer) == 0 {
|
||||
// no A or AAAA found. Look for CNAME
|
||||
m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...)
|
||||
if len(m.Answer) != 0 {
|
||||
// Resolve CNAME
|
||||
for _, nameRR := range m.Answer {
|
||||
if t, ok := nameRR.(*dns.CNAME); ok {
|
||||
m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dns.TypeNS:
|
||||
// Resove NS records
|
||||
for _, nsRR := range m.Answer {
|
||||
if t, ok := nsRR.(*dns.NS); ok {
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...)
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.Answer) == 0 {
|
||||
m.SetRcode(m, dns.RcodeNameError)
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
configPath := flag.String("c", "/etc/cool-dns/config.yaml", "path to the config file")
|
||||
flag.Parse()
|
||||
|
||||
start(*configPath)
|
||||
|
||||
sig := make(chan os.Signal)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
s := <-sig
|
||||
log.Printf("Signal (%v) received, stopping\n", s)
|
||||
os.Exit(0)
|
||||
}
|
||||
@@ -5,7 +5,7 @@ zones:
|
||||
- zone: example.com.
|
||||
file: zonefile2.txt
|
||||
acl:
|
||||
- lan
|
||||
- vpn
|
||||
|
||||
acl:
|
||||
- name: vpn
|
||||
@@ -23,7 +23,7 @@ forward:
|
||||
address: 0.0.0.0:8053
|
||||
|
||||
tls:
|
||||
enable: true
|
||||
enable: false
|
||||
address: 0.0.0.0:8853
|
||||
cert: cert.crt
|
||||
key: private.key
|
||||
136
internal/authoritative.go
Normal file
136
internal/authoritative.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package cooldns
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type rrMap map[uint16]map[string][]dns.RR
|
||||
|
||||
// All record types to send when a ANY request is send
|
||||
var anyRecordTypes = []uint16{
|
||||
dns.TypeSOA,
|
||||
dns.TypeA,
|
||||
dns.TypeAAAA,
|
||||
dns.TypeNS,
|
||||
dns.TypeCNAME,
|
||||
dns.TypeMX,
|
||||
dns.TypeTXT,
|
||||
dns.TypeSRV,
|
||||
dns.TypeCAA,
|
||||
}
|
||||
|
||||
// handleRequest find the right RR(s) in the view and send them back
|
||||
func handleRequest(w dns.ResponseWriter, r *dns.Msg, zone zoneView) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
|
||||
// Only support one question per query because all the other server also does that
|
||||
if len(r.Question) != 1 {
|
||||
rcodeRequest(w, r, dns.RcodeServerFailure)
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
|
||||
rrs := zone.rr[q.Qtype]
|
||||
|
||||
// Handle ANY
|
||||
if q.Qtype == dns.TypeANY {
|
||||
for _, rrType := range anyRecordTypes {
|
||||
m.Answer = append(m.Answer, zone.rr[rrType][q.Name]...)
|
||||
}
|
||||
} else {
|
||||
// Handle any other type
|
||||
m.Answer = append(m.Answer, rrs[q.Name]...)
|
||||
|
||||
// if no rr found yet
|
||||
if len(m.Answer) == 0 {
|
||||
// Check for wildcard
|
||||
parts := dns.SplitDomainName(q.Name)[1:]
|
||||
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
|
||||
foundDomain := rrs[searchDomain]
|
||||
for _, rr := range foundDomain {
|
||||
newRR := rr
|
||||
newRR.Header().Name = q.Name
|
||||
m.Answer = append(m.Answer, newRR)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle extras
|
||||
switch q.Qtype {
|
||||
// Dont handle extra stuff when answering ANY request
|
||||
// case dns.TypeANY:
|
||||
// fallthrough
|
||||
case dns.TypeMX:
|
||||
// Resolve MX domains
|
||||
for _, mxRR := range m.Answer {
|
||||
if t, ok := mxRR.(*dns.MX); ok {
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Mx]...)
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Mx]...)
|
||||
}
|
||||
}
|
||||
case dns.TypeA, dns.TypeAAAA:
|
||||
if len(m.Answer) == 0 {
|
||||
// no A or AAAA found. Look for CNAME
|
||||
m.Answer = append(m.Answer, zone.rr[dns.TypeCNAME][q.Name]...)
|
||||
if len(m.Answer) != 0 {
|
||||
// Resolve CNAME
|
||||
for _, nameRR := range m.Answer {
|
||||
if t, ok := nameRR.(*dns.CNAME); ok {
|
||||
m.Answer = append(m.Answer, zone.rr[q.Qtype][t.Target]...)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No direct A/AAAA or CNAME found. Check for CNAME wildcard
|
||||
parts := dns.SplitDomainName(q.Name)[1:]
|
||||
searchDomain := "*." + dns.Fqdn(strings.Join(parts, "."))
|
||||
foundDomain := zone.rr[dns.TypeCNAME][searchDomain]
|
||||
for _, rr := range foundDomain {
|
||||
// Add CNAME to answer section
|
||||
newRR := rr
|
||||
newRR.Header().Name = q.Name
|
||||
m.Answer = append(m.Answer, newRR)
|
||||
|
||||
// Add resolved CNAME to *also* to the answer section (bind does the same soo)
|
||||
if t, ok := rr.(*dns.CNAME); ok {
|
||||
m.Answer = append(m.Answer, zone.rr[dns.TypeA][t.Target]...)
|
||||
m.Answer = append(m.Answer, zone.rr[dns.TypeAAAA][t.Target]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case dns.TypeNS:
|
||||
// Resove NS records
|
||||
for _, nsRR := range m.Answer {
|
||||
if t, ok := nsRR.(*dns.NS); ok {
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Ns]...)
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Ns]...)
|
||||
}
|
||||
}
|
||||
case dns.TypeCNAME:
|
||||
// Resolve CNAME
|
||||
for _, cnameRR := range m.Answer {
|
||||
if t, ok := cnameRR.(*dns.CNAME); ok {
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeA][t.Target]...)
|
||||
m.Extra = append(m.Extra, zone.rr[dns.TypeAAAA][t.Target]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.Answer) == 0 {
|
||||
var soa dns.RR
|
||||
for _, v := range zone.rr[dns.TypeSOA] {
|
||||
if len(v) == 1 {
|
||||
soa = v[0]
|
||||
}
|
||||
}
|
||||
if soa != nil {
|
||||
m.Extra = append(m.Extra, soa)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cooldns
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
const blockTTL uint32 = 300
|
||||
const blockTTL uint32 = 604800
|
||||
|
||||
var nullIPv4 = net.IPv4(0, 0, 0, 0)
|
||||
var nullIPv6 = net.ParseIP("::/0")
|
||||
@@ -86,12 +86,10 @@ func parseRawBlacklist(blacklist configBlacklist, raw string) []string {
|
||||
// parseHostFormat parse the string in the format of a hostfile
|
||||
func parseHostFormat(raw string) []string {
|
||||
finalList := make([]string, 0)
|
||||
reg := regexp.MustCompile(`(?mi)^\s*(#*)\s*(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})\s+([a-zA-Z0-9\.\- ]+)$`)
|
||||
reg := regexp.MustCompile(`(?m)^\s*(0\.0\.0\.0) ([a-zA-Z0-9-.]*)`)
|
||||
matches := reg.FindAllStringSubmatch(raw, -1)
|
||||
for _, match := range matches {
|
||||
if match[1] != "#" {
|
||||
finalList = append(finalList, dns.Fqdn(match[3]))
|
||||
}
|
||||
finalList = append(finalList, dns.Fqdn(match[2]))
|
||||
}
|
||||
|
||||
return finalList
|
||||
144
internal/config.go
Normal file
144
internal/config.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package cooldns
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// config format of the config file
|
||||
type config struct {
|
||||
Zones []configZone `yaml:"zones"`
|
||||
ACL []configACL `yaml:"acl"`
|
||||
Forward configForward `yaml:"forward"`
|
||||
Address string `yaml:"address"`
|
||||
Blacklist []configBlacklist `yaml:"blacklist"`
|
||||
TLS configTLS `yaml:"tls"`
|
||||
Lego configLego `yaml:"lego"`
|
||||
}
|
||||
|
||||
type configForward struct {
|
||||
Enable bool `yaml:"enable"`
|
||||
ACL []string `yaml:"acl"`
|
||||
Server string `yaml:"server"`
|
||||
}
|
||||
|
||||
type configACL struct {
|
||||
Name string `yaml:"name"`
|
||||
CIDR string `yaml:"cidr"`
|
||||
}
|
||||
|
||||
type configZone struct {
|
||||
Zone string `yaml:"zone"`
|
||||
File string `yaml:"file"`
|
||||
ACL []string `yaml:"acl"`
|
||||
}
|
||||
|
||||
type configBlacklist struct {
|
||||
URL string `yaml:"url"`
|
||||
Format string `yaml:"format"`
|
||||
}
|
||||
|
||||
type configTLS struct {
|
||||
Enable bool `yaml:"enable"`
|
||||
Address string `yaml:"address"`
|
||||
Cert string `yaml:"cert"`
|
||||
Key string `yaml:"key"`
|
||||
}
|
||||
|
||||
func loadConfig(configPath string) (*config, error) {
|
||||
file, err := ioutil.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var loadedConfig config
|
||||
|
||||
err = yaml.Unmarshal(file, &loadedConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &loadedConfig, nil
|
||||
}
|
||||
|
||||
func loadZones(configZones []configZone) (zoneMap, error) {
|
||||
zones := make(zoneMap)
|
||||
for _, z := range configZones {
|
||||
rrs, err := loadZonefile(z.File, z.Zone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if zones[z.Zone] == nil {
|
||||
zones[z.Zone] = make([]zoneView, 0)
|
||||
}
|
||||
zones[z.Zone] = append(zones[z.Zone], zoneView{
|
||||
rr: createRRMap(rrs),
|
||||
acl: z.ACL,
|
||||
})
|
||||
|
||||
log.Printf("Loaded zone %s\n", z.Zone)
|
||||
}
|
||||
|
||||
return zones, nil
|
||||
}
|
||||
|
||||
// createRRMap order the rr into a structure that is more easy to use
|
||||
func createRRMap(rrs []dns.RR) rrMap {
|
||||
rrMap := make(rrMap)
|
||||
for _, rr := range rrs {
|
||||
if rrMap[rr.Header().Rrtype] == nil {
|
||||
rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR)
|
||||
}
|
||||
|
||||
if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil {
|
||||
rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0)
|
||||
}
|
||||
rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr)
|
||||
}
|
||||
|
||||
return rrMap
|
||||
}
|
||||
|
||||
func loadZonefile(filepath, origin string) ([]dns.RR, error) {
|
||||
file, err := os.Open(filepath)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
parser := dns.NewZoneParser(file, origin, filepath)
|
||||
|
||||
var rrs = make([]dns.RR, 0)
|
||||
|
||||
for rr, ok := parser.Next(); ok; rr, ok = parser.Next() {
|
||||
rrs = append(rrs, rr)
|
||||
}
|
||||
|
||||
if err := parser.Err(); err != nil {
|
||||
log.Println(err)
|
||||
}
|
||||
|
||||
return rrs, nil
|
||||
}
|
||||
|
||||
// createACLList create a map with the CIDR and the name of the rule
|
||||
func createACLList(config []configACL) (map[string]*net.IPNet, error) {
|
||||
acls := make(map[string]*net.IPNet)
|
||||
|
||||
for _, aclRule := range config {
|
||||
_, mask, err := net.ParseCIDR(aclRule.CIDR)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
acls[aclRule.Name] = mask
|
||||
}
|
||||
|
||||
return acls, nil
|
||||
}
|
||||
192
internal/cooldns.go
Normal file
192
internal/cooldns.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package cooldns
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type zoneView struct {
|
||||
rr rrMap
|
||||
acl []string
|
||||
}
|
||||
|
||||
type zoneMap map[string][]zoneView
|
||||
|
||||
// Start starts cooldns
|
||||
func Start(configPath string) {
|
||||
config, err := loadConfig(configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load config: %s\n", err.Error())
|
||||
}
|
||||
|
||||
err = os.Chdir(filepath.Dir(configPath))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to goto config dir: %s", err.Error())
|
||||
}
|
||||
|
||||
zones, err := loadZones(config.Zones)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to load zones: %s\n", err.Error())
|
||||
}
|
||||
|
||||
aclList, err := createACLList(config.ACL)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse ACL rules: %s\n", err.Error())
|
||||
}
|
||||
|
||||
blacklist := loadBlacklist(config.Blacklist)
|
||||
|
||||
var acmeMap *legoMap
|
||||
if config.Lego.Enable {
|
||||
acmeMap = startLEGOWebSever(config.Lego)
|
||||
}
|
||||
|
||||
server := createServer(zones, *config, aclList, blacklist, acmeMap)
|
||||
|
||||
listenAndServer(server, config.Address)
|
||||
|
||||
if config.TLS.Enable {
|
||||
listenAndServerTLS(server, config.TLS.Address, config.TLS.Cert, config.TLS.Key)
|
||||
|
||||
log.Printf("Start listening on tcp %s for tls", config.TLS.Address)
|
||||
}
|
||||
|
||||
log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address)
|
||||
}
|
||||
|
||||
// createServer creates a new serve mux. Adds all the logic to handle the request
|
||||
func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux {
|
||||
srv := dns.NewServeMux()
|
||||
c := new(dns.Client)
|
||||
|
||||
// For all zones set from the config
|
||||
for zoneName, zones := range zones {
|
||||
srv.HandleFunc(zoneName, createHandler(zones, config, aclList, acmeList))
|
||||
}
|
||||
|
||||
// Handle any other request for forwarding
|
||||
srv.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// Parse IP
|
||||
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ip := net.ParseIP(remoteIP)
|
||||
|
||||
if err != nil && ip != nil {
|
||||
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is a ACME DNS-01 challange
|
||||
if config.Lego.Enable && handleACMERequest(w, r, acmeList) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check ACL rules
|
||||
if config.Forward.Enable && !checkACL(config.Forward.ACL, aclList, ip) {
|
||||
rcodeRequest(w, r, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the domain is bocked
|
||||
if _, ok := blacklist[r.Question[0].Name]; ok {
|
||||
handleBlockedDomain(w, r)
|
||||
} else {
|
||||
// Forward request
|
||||
in, _, err := c.Exchange(r, config.Forward.Server)
|
||||
|
||||
if err != nil {
|
||||
rcodeRequest(w, r, dns.RcodeServerFailure)
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteMsg(in)
|
||||
}
|
||||
})
|
||||
|
||||
return srv
|
||||
}
|
||||
|
||||
func listenAndServer(server *dns.ServeMux, address string) {
|
||||
// Start UDP listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServe(address, "udp", server); err != nil {
|
||||
log.Fatalf("Failed to set udp listener %s\n", err.Error())
|
||||
}
|
||||
}()
|
||||
|
||||
// Start TCP listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServe(address, "tcp", server); err != nil {
|
||||
log.Fatalf("Failed to set tcp listener %s\n", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func listenAndServerTLS(server *dns.ServeMux, address, cert, key string) {
|
||||
// Start TLS listner
|
||||
go func() {
|
||||
if err := dns.ListenAndServeTLS(address, cert, key, server); err != nil {
|
||||
log.Fatalf("Failed to set DoT listener %s", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func checkACL(alcRules []string, aclList map[string]*net.IPNet, ip net.IP) bool {
|
||||
if len(alcRules) != 0 {
|
||||
passed := false
|
||||
for _, rule := range alcRules {
|
||||
|
||||
if aclList[rule].Contains(ip) {
|
||||
passed = true
|
||||
}
|
||||
}
|
||||
return passed
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// rcodeRequest respond to a request with a response code
|
||||
func rcodeRequest(w dns.ResponseWriter, r *dns.Msg, rcode int) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.SetRcode(r, rcode)
|
||||
w.WriteMsg(m)
|
||||
}
|
||||
|
||||
func createHandler(zones []zoneView, config config, aclList map[string]*net.IPNet, acmeList *legoMap) func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
return func(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// Parse IP
|
||||
remoteIP, _, err := net.SplitHostPort(w.RemoteAddr().String())
|
||||
ip := net.ParseIP(remoteIP)
|
||||
if err != nil && ip != nil {
|
||||
log.Printf("Faild to parse remote IP WTF? :%s\n", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is a ACME DNS-01 challange
|
||||
if config.Lego.Enable && handleACMERequest(w, r, acmeList) {
|
||||
return
|
||||
}
|
||||
|
||||
// find out what view to handle the request
|
||||
zoneIndex := -1
|
||||
|
||||
for i, zone := range zones {
|
||||
if (len(zone.acl) == 0 && zoneIndex == -1) || checkACL(zone.acl, aclList, ip) {
|
||||
zoneIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
// No view found that can handle the request
|
||||
if zoneIndex == -1 {
|
||||
rcodeRequest(w, r, dns.RcodeRefused)
|
||||
return
|
||||
}
|
||||
|
||||
handleRequest(w, r, zones[zoneIndex])
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package main
|
||||
package cooldns
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
91
rr_test.go
91
rr_test.go
@@ -1,91 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func init() {
|
||||
start("test/rrConfig.yaml")
|
||||
}
|
||||
|
||||
// Helper
|
||||
|
||||
func request(name string, rrType uint16) (*dns.Msg, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(name), rrType)
|
||||
return dns.Exchange(m, "127.0.0.1:8053")
|
||||
}
|
||||
|
||||
func containsA(haystack []dns.RR, name, ip string) bool {
|
||||
searchIP := net.ParseIP(ip)
|
||||
for _, v := range haystack {
|
||||
if v.Header().Name == dns.Fqdn(name) && v.Header().Rrtype == dns.TypeA {
|
||||
if t, ok := v.(*dns.A); ok {
|
||||
if t.A.Equal(searchIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func containsAAAA(haystack []dns.RR, name, ip string) bool {
|
||||
searchIP := net.ParseIP(ip)
|
||||
for _, v := range haystack {
|
||||
if v.Header().Name == dns.Fqdn(name) && v.Header().Rrtype == dns.TypeAAAA {
|
||||
if t, ok := v.(*dns.AAAA); ok {
|
||||
if t.AAAA.Equal(searchIP) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func TestNormalA(t *testing.T) {
|
||||
res, err := request("example.com", dns.TypeA)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !containsA(res.Answer, "example.com", "1.2.3.1") || !containsA(res.Answer, "example.com", "1.2.3.2") {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalAAAA(t *testing.T) {
|
||||
res, err := request("example.com", dns.TypeAAAA)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !containsAAAA(res.Answer, "example.com", "2001:db8:10::1") {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalSOA(t *testing.T) {
|
||||
res, err := request("example.com", dns.TypeSOA)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if len(res.Answer) != 1 {
|
||||
t.Fatalf("Should only be 1 SOA got %d", len(res.Answer))
|
||||
}
|
||||
|
||||
if soa, ok := res.Answer[0].(*dns.SOA); ok {
|
||||
if soa.Ns != "ns.example.com." {
|
||||
t.Fatal("Wrong SOA rr")
|
||||
}
|
||||
} else {
|
||||
t.Fatal("Answer is not a SOA rr")
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
zones:
|
||||
- zone: example.com.
|
||||
file: zonefile.txt
|
||||
|
||||
address: 0.0.0.0:8053
|
||||
@@ -1,20 +0,0 @@
|
||||
$ORIGIN example.com.
|
||||
$TTL 3600
|
||||
example.com. IN SOA ns.example.com. username.example.com. ( 2020091025 7200 3600 1209600 3600 )
|
||||
example.com. IN NS ns
|
||||
example.com. IN NS ns.somewhere.example.
|
||||
example.com. IN MX 10 mail.example.com.
|
||||
@ IN MX 20 mail2.example.com.
|
||||
@ IN MX 50 mail3
|
||||
example.com. IN A 1.2.3.1
|
||||
example.com. IN A 1.2.3.2
|
||||
IN AAAA 2001:db8:10::1
|
||||
ns IN A 1.2.3.3
|
||||
IN AAAA 2001:db8:10::2
|
||||
www IN CNAME example.com.
|
||||
wwwtest IN CNAME www
|
||||
mail IN A 1.2.3.4
|
||||
mail2 IN A 1.2.3.5
|
||||
mail3 IN A 1.2.3.6
|
||||
*.www IN A 1.2.3.7
|
||||
a.www IN A 1.2.3.8
|
||||
Reference in New Issue
Block a user