From 257fcef0b817c2cd461640dc6e9a808a3257f942 Mon Sep 17 00:00:00 2001 From: KaitoHH Date: Sun, 30 Apr 2023 12:18:20 +0800 Subject: [PATCH] Fix: adjust DNS TTL values based on minimum value (#2706) This commit adds an updated function that adjusts the TTL values of DNS records are based on the minimum TTL the value found in the records list so that all records share the same TTL value. This ensures consistency in the cache expiry time for all records to prevent caching issues. --- dns/resolver.go | 3 ++- dns/util.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/dns/resolver.go b/dns/resolver.go index fa20764..50e0c3d 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -143,7 +143,8 @@ func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, e setMsgTTL(msg, uint32(1)) // Continue fetch go r.exchangeWithoutCache(ctx, m) } else { - setMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) + // updating TTL by subtracting common delta time from each DNS record + updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds())) } return } diff --git a/dns/util.go b/dns/util.go index 90d2ff9..0600360 100644 --- a/dns/util.go +++ b/dns/util.go @@ -14,8 +14,25 @@ import ( "github.com/Dreamacro/clash/log" D "github.com/miekg/dns" + "github.com/samber/lo" ) +func minimalTTL(records []D.RR) uint32 { + return lo.MinBy(records, func(r1 D.RR, r2 D.RR) bool { + return r1.Header().Ttl < r2.Header().Ttl + }).Header().Ttl +} + +func updateTTL(records []D.RR, ttl uint32) { + if len(records) == 0 { + return + } + delta := minimalTTL(records) - ttl + for i := range records { + records[i].Header().Ttl = lo.Clamp(records[i].Header().Ttl-delta, 1, records[i].Header().Ttl) + } +} + func putMsgToCache(c *cache.LruCache, key string, q D.Question, msg *D.Msg) { // skip dns cache for acme challenge if q.Qtype == D.TypeTXT && strings.HasPrefix(q.Name, "_acme-challenge.") { @@ -26,11 +43,11 @@ func putMsgToCache(c *cache.LruCache, key string, q D.Question, msg *D.Msg) { var ttl uint32 switch { case len(msg.Answer) != 0: - ttl = msg.Answer[0].Header().Ttl + ttl = minimalTTL(msg.Answer) case len(msg.Ns) != 0: - ttl = msg.Ns[0].Header().Ttl + ttl = minimalTTL(msg.Ns) case len(msg.Extra) != 0: - ttl = msg.Extra[0].Header().Ttl + ttl = minimalTTL(msg.Extra) default: log.Debugln("[DNS] response msg empty: %#v", msg) return @@ -53,6 +70,12 @@ func setMsgTTL(msg *D.Msg, ttl uint32) { } } +func updateMsgTTL(msg *D.Msg, ttl uint32) { + updateTTL(msg.Answer, ttl) + updateTTL(msg.Ns, ttl) + updateTTL(msg.Extra, ttl) +} + func isIPRequest(q D.Question) bool { return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA) }