general unix domain sockets support (#45)

* feat: allow binding to unix domain sockets

this is useful when the user does not want to expose more tcp ports than
needed. also simplifes configuration in some situation, like with nixos
modules as the socket paths can be automatically configured.

docs updated with additional configuration flags.

Signed-off-by: Cassie Cheung <me@soopy.moe>

* feat: graceful shutdown and cleanup on signal

this is needed to clean up left-over unix sockets, else on the next boot
listener panics with `address already in use`.

Co-authored-by: cat <cat@gensokyo.uk>
Signed-off-by: Cassie Cheung <me@soopy.moe>

* feat: support unix socket upstream targets

adds support for proxying unix socket upstreams, essentially allowing
anubis to run without listening on tcp sockets at all*.

*for metrics, neither prometheus and victoriametrics supports scraping
from unix sockets. if metrics are desired, tcp sockets are still needed.

Co-authored-by: cat <cat@gensokyo.uk>
Signed-off-by: Cassie Cheung <me@soopy.moe>

* docs: add changelog entry

---------

Signed-off-by: Cassie Cheung <me@soopy.moe>
Co-authored-by: cat <cat@gensokyo.uk>
This commit is contained in:
soopyc 2025-03-21 22:58:05 +08:00 committed by GitHub
parent d93adbc111
commit 1c00431098
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 135 additions and 17 deletions

View file

@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
@ -15,12 +16,16 @@ import (
"log/slog"
"math"
mrand "math/rand"
"net"
"net/http"
"net/http/httputil"
"net/url"
"os"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/TecharoHQ/anubis"
@ -37,9 +42,12 @@ import (
)
var (
bind = flag.String("bind", ":8923", "TCP port to bind HTTP to")
bind = flag.String("bind", ":8923", "network address to bind HTTP to")
bindNetwork = flag.String("bind-network", "tcp", "network family to bind HTTP to, e.g. unix, tcp")
challengeDifficulty = flag.Int("difficulty", 4, "difficulty of the challenge")
metricsBind = flag.String("metrics-bind", ":9090", "TCP port to bind metrics to")
metricsBind = flag.String("metrics-bind", ":9090", "network address to bind metrics to")
metricsBindNetwork = flag.String("metrics-bind-network", "tcp", "network family for the metrics server to bind to")
socketMode = flag.String("socket-mode", "0770", "socket mode (permissions) for unix domain sockets.")
robotsTxt = flag.Bool("serve-robots-txt", false, "serve a robots.txt file that disallows all robots")
policyFname = flag.String("policy-fname", "", "full path to anubis policy document (defaults to a sensible built-in policy)")
slogLevel = flag.String("slog-level", "INFO", "logging level (see https://pkg.go.dev/log/slog#hdr-Levels)")
@ -101,6 +109,40 @@ func doHealthCheck() error {
return nil
}
func setupListener(network string, address string) (net.Listener, string) {
formattedAddress := ""
switch network {
case "unix":
formattedAddress = "unix:" + address
case "tcp":
formattedAddress = "http://localhost" + address
default:
formattedAddress = fmt.Sprintf(`(%s) %s`, network, address)
}
listener, err := net.Listen(network, address)
if err != nil {
log.Fatal(fmt.Errorf("failed to bind to %s: %w", formattedAddress, err))
}
// additional permission handling for unix sockets
if network == "unix" {
mode, err := strconv.ParseUint(*socketMode, 8, 0)
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not parse socket mode %s: %w", *socketMode, err))
}
err = os.Chmod(address, os.FileMode(mode))
if err != nil {
listener.Close()
log.Fatal(fmt.Errorf("could not change socket mode: %w", err))
}
}
return listener, formattedAddress
}
func main() {
flagenv.Parse()
flag.Parse()
@ -155,20 +197,59 @@ func main() {
})
}
wg := new(sync.WaitGroup)
// install signal handler
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
if *metricsBind != "" {
go metricsServer()
wg.Add(1)
go metricsServer(ctx, wg.Done)
}
mux.HandleFunc("/", s.maybeReverseProxy)
slog.Info("listening", "url", "http://localhost"+*bind, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
log.Fatal(http.ListenAndServe(*bind, mux))
srv := http.Server{Handler: mux}
listener, url := setupListener(*bindNetwork, *bind)
slog.Info("listening", "url", url, "difficulty", *challengeDifficulty, "serveRobotsTXT", *robotsTxt, "target", *target, "version", anubis.Version)
go func() {
<-ctx.Done()
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(c); err != nil {
log.Printf("cannot shut down: %v", err)
}
}()
if err := srv.Serve(listener); err != http.ErrServerClosed {
log.Fatal(err)
}
wg.Wait()
}
func metricsServer() {
http.DefaultServeMux.Handle("/metrics", promhttp.Handler())
slog.Debug("listening for metrics", "url", "http://localhost"+*metricsBind)
log.Fatal(http.ListenAndServe(*metricsBind, nil))
func metricsServer(ctx context.Context, done func()) {
defer done()
mux := http.NewServeMux()
mux.Handle("/metrics", promhttp.Handler())
srv := http.Server{Handler: mux}
listener, url := setupListener(*metricsBindNetwork, *metricsBind)
slog.Debug("listening for metrics", "url", url)
go func() {
<-ctx.Done()
c, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := srv.Shutdown(c); err != nil {
log.Printf("cannot shut down: %v", err)
}
}()
if err := srv.Serve(listener); err != http.ErrServerClosed {
log.Fatal(err)
}
}
func sha256sum(text string) (string, error) {
@ -207,7 +288,24 @@ func New(target, policyFname string) (*Server, error) {
return nil, fmt.Errorf("failed to generate ed25519 key: %w", err)
}
transport := http.DefaultTransport.(*http.Transport).Clone()
// https://github.com/oauth2-proxy/oauth2-proxy/blob/4e2100a2879ef06aea1411790327019c1a09217c/pkg/upstream/http.go#L124
if u.Scheme == "unix" {
// clean path up so we don't use the socket path in proxied requests
addr := u.Path
u.Path = ""
// tell transport how to dial unix sockets
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
dialer := net.Dialer{}
return dialer.DialContext(ctx, "unix", addr)
}
// tell transport how to handle the unix url scheme
transport.RegisterProtocol("unix", unixRoundTripper{Transport: transport})
}
rp := httputil.NewSingleHostReverseProxy(u)
rp.Transport = transport
var fin io.ReadCloser
@ -240,6 +338,22 @@ func New(target, policyFname string) (*Server, error) {
}, nil
}
// https://github.com/oauth2-proxy/oauth2-proxy/blob/master/pkg/upstream/http.go#L124
type unixRoundTripper struct {
Transport *http.Transport
}
// set bare minimum stuff
func (t unixRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
req = req.Clone(req.Context())
if req.Host == "" {
req.Host = "localhost"
}
req.URL.Host = req.Host // proxy error: no Host in request URL
req.URL.Scheme = "http" // make http.Transport happy and avoid an infinite recursion
return t.Transport.RoundTrip(req)
}
type Server struct {
rp *httputil.ReverseProxy
priv ed25519.PrivateKey