diff --git a/internal/config.go b/internal/config.go new file mode 100644 index 0000000..af57258 --- /dev/null +++ b/internal/config.go @@ -0,0 +1,143 @@ +package cooldns + +import ( + "io/ioutil" + "log" + "net" + "os" + + "github.com/miekg/dns" + "gopkg.in/yaml.v3" +) + +// config format of the config file +type config struct { + Zones []configZone `yaml:"zones"` + ACL []configACL `yaml:"acl"` + Forward configForward `yaml:"forward"` + Address string `yaml:"address"` + Blacklist []configBlacklist `yaml:"blacklist"` + TLS configTLS `yaml:"tls"` + Lego configLego `yaml:"lego"` +} + +type configForward struct { + ACL []string `yaml:"acl"` + Server string `yaml:"server"` +} + +type configACL struct { + Name string `yaml:"name"` + CIDR string `yaml:"cidr"` +} + +type configZone struct { + Zone string `yaml:"zone"` + File string `yaml:"file"` + ACL []string `yaml:"acl"` +} + +type configBlacklist struct { + URL string `yaml:"url"` + Format string `yaml:"format"` +} + +type configTLS struct { + Enable bool `yaml:"enable"` + Address string `yaml:"address"` + Cert string `yaml:"cert"` + Key string `yaml:"key"` +} + +func loadConfig(configPath string) (*config, error) { + file, err := ioutil.ReadFile(configPath) + if err != nil { + return nil, err + } + + var loadedConfig config + + err = yaml.Unmarshal(file, &loadedConfig) + if err != nil { + return nil, err + } + + return &loadedConfig, nil +} + +func loadZones(configZones []configZone) (zoneMap, error) { + zones := make(zoneMap) + for _, z := range configZones { + rrs, err := loadZonefile(z.File, z.Zone) + if err != nil { + return nil, err + } + if zones[z.Zone] == nil { + zones[z.Zone] = make([]zoneView, 0) + } + zones[z.Zone] = append(zones[z.Zone], zoneView{ + rr: createRRMap(rrs), + acl: z.ACL, + }) + + log.Printf("Loaded zone %s\n", z.Zone) + } + + return zones, nil +} + +// createRRMap order the rr into a structure that is more easy to use +func createRRMap(rrs []dns.RR) rrMap { + rrMap := make(rrMap) + for _, rr := range rrs { + if rrMap[rr.Header().Rrtype] == nil { + rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR) + } + + if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil { + rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0) + } + rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr) + } + + return rrMap +} + +func loadZonefile(filepath, origin string) ([]dns.RR, error) { + file, err := os.Open(filepath) + + if err != nil { + return nil, err + } + + parser := dns.NewZoneParser(file, origin, filepath) + + var rrs = make([]dns.RR, 0) + + for rr, ok := parser.Next(); ok; rr, ok = parser.Next() { + rrs = append(rrs, rr) + } + + if err := parser.Err(); err != nil { + log.Println(err) + } + + return rrs, nil +} + +// createACLList create a map with the CIDR and the name of the rule +func createACLList(config []configACL) (map[string]*net.IPNet, error) { + acls := make(map[string]*net.IPNet) + + for _, aclRule := range config { + _, mask, err := net.ParseCIDR(aclRule.CIDR) + + if err != nil { + return nil, err + } + + acls[aclRule.Name] = mask + } + + return acls, nil +} diff --git a/internal/cooldns.go b/internal/cooldns.go index b6d93af..f72a676 100644 --- a/internal/cooldns.go +++ b/internal/cooldns.go @@ -1,7 +1,6 @@ package cooldns import ( - "io/ioutil" "log" "net" "os" @@ -9,7 +8,6 @@ import ( "strings" "github.com/miekg/dns" - "gopkg.in/yaml.v3" ) type zoneView struct { @@ -21,45 +19,6 @@ type zoneMap map[string][]zoneView type rrMap map[uint16]map[string][]dns.RR -// config format of the config file -type config struct { - Zones []configZone `yaml:"zones"` - ACL []configACL `yaml:"acl"` - Forward configForward `yaml:"forward"` - Address string `yaml:"address"` - Blacklist []configBlacklist `yaml:"blacklist"` - TLS configTLS `yaml:"tls"` - Lego configLego `yaml:"lego"` -} - -type configForward struct { - ACL []string `yaml:"acl"` - Server string `yaml:"server"` -} - -type configACL struct { - Name string `yaml:"name"` - CIDR string `yaml:"cidr"` -} - -type configZone struct { - Zone string `yaml:"zone"` - File string `yaml:"file"` - ACL []string `yaml:"acl"` -} - -type configBlacklist struct { - URL string `yaml:"url"` - Format string `yaml:"format"` -} - -type configTLS struct { - Enable bool `yaml:"enable"` - Address string `yaml:"address"` - Cert string `yaml:"cert"` - Key string `yaml:"key"` -} - // All record types to send when a ANY request is send var anyRecordTypes = []uint16{ dns.TypeSOA, @@ -115,99 +74,6 @@ func Start(configPath string) { log.Printf("Start listening on udp %s and tcp %s\n", config.Address, config.Address) } -func loadConfig(configPath string) (*config, error) { - file, err := ioutil.ReadFile(configPath) - if err != nil { - return nil, err - } - - var loadedConfig config - - err = yaml.Unmarshal(file, &loadedConfig) - if err != nil { - return nil, err - } - - return &loadedConfig, nil -} - -func loadZones(configZones []configZone) (zoneMap, error) { - zones := make(zoneMap) - for _, z := range configZones { - rrs, err := loadZonefile(z.File, z.Zone) - if err != nil { - return nil, err - } - if zones[z.Zone] == nil { - zones[z.Zone] = make([]zoneView, 0) - } - zones[z.Zone] = append(zones[z.Zone], zoneView{ - rr: createRRMap(rrs), - acl: z.ACL, - }) - - log.Printf("Loaded zone %s\n", z.Zone) - } - - return zones, nil -} - -// createRRMap order the rr into a structure that is more easy to use -func createRRMap(rrs []dns.RR) rrMap { - rrMap := make(rrMap) - for _, rr := range rrs { - if rrMap[rr.Header().Rrtype] == nil { - rrMap[rr.Header().Rrtype] = make(map[string][]dns.RR) - } - - if rrMap[rr.Header().Rrtype][rr.Header().Name] == nil { - rrMap[rr.Header().Rrtype][rr.Header().Name] = make([]dns.RR, 0) - } - rrMap[rr.Header().Rrtype][rr.Header().Name] = append(rrMap[rr.Header().Rrtype][rr.Header().Name], rr) - } - - return rrMap -} - -func loadZonefile(filepath, origin string) ([]dns.RR, error) { - file, err := os.Open(filepath) - - if err != nil { - return nil, err - } - - parser := dns.NewZoneParser(file, origin, filepath) - - var rrs = make([]dns.RR, 0) - - for rr, ok := parser.Next(); ok; rr, ok = parser.Next() { - rrs = append(rrs, rr) - } - - if err := parser.Err(); err != nil { - log.Println(err) - } - - return rrs, nil -} - -// createACLList create a map with the CIDR and the name of the rule -func createACLList(config []configACL) (map[string]*net.IPNet, error) { - acls := make(map[string]*net.IPNet) - - for _, aclRule := range config { - _, mask, err := net.ParseCIDR(aclRule.CIDR) - - if err != nil { - return nil, err - } - - acls[aclRule.Name] = mask - } - - return acls, nil -} - // createServer creates a new serve mux. Adds all the logic to handle the request func createServer(zones zoneMap, config config, aclList map[string]*net.IPNet, blacklist map[string]bool, acmeList *legoMap) *dns.ServeMux { srv := dns.NewServeMux()