Skip to content

Commit

Permalink
Fix the handling for concurrent queries over UDP
Browse files Browse the repository at this point in the history
Signed-off-by: Santhosh Manohar <santhosh@docker.com>
  • Loading branch information
Santhosh Manohar committed Apr 1, 2016
1 parent f7e3338 commit f113c9a
Showing 1 changed file with 76 additions and 10 deletions.
86 changes: 76 additions & 10 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@ const (
defaultRespSize = 512
maxConcurrent = 50
logInterval = 2 * time.Second
maxDNSID = 65536
)

type clientConn struct {
dnsID uint16
respWriter dns.ResponseWriter
}

type extDNSEntry struct {
ipStr string
extConn net.Conn
Expand All @@ -69,6 +75,7 @@ type resolver struct {
count int32
tStamp time.Time
queryLock sync.Mutex
client map[uint16]clientConn
}

func init() {
Expand All @@ -78,8 +85,9 @@ func init() {
// NewResolver creates a new instance of the Resolver
func NewResolver(sb *sandbox) Resolver {
return &resolver{
sb: sb,
err: fmt.Errorf("setup not done yet"),
sb: sb,
err: fmt.Errorf("setup not done yet"),
client: make(map[uint16]clientConn),
}
}

Expand Down Expand Up @@ -375,7 +383,9 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
extConn.SetDeadline(time.Now().Add(extIOTimeout))
co := &dns.Conn{Conn: extConn}

if r.concurrentQueryInc() == false {
// forwardQueryStart stores required context to mux multiple client queries over
// one connection; and limits the number of outstanding concurrent queries.
if r.forwardQueryStart(w, query) == false {
old := r.tStamp
r.tStamp = time.Now()
if r.tStamp.Sub(old) > logInterval {
Expand All @@ -391,18 +401,25 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
}()
err = co.WriteMsg(query)
if err != nil {
r.concurrentQueryDec()
r.forwardQueryEnd(w, query)
log.Debugf("Send to DNS server failed, %s", err)
continue
}

resp, err = co.ReadMsg()
r.concurrentQueryDec()
if err != nil {
r.forwardQueryEnd(w, query)
log.Debugf("Read from DNS server failed, %s", err)
continue
}

// Retrieves the context for the forwarded query and returns the client connection
// to send the reply to
w = r.forwardQueryEnd(w, resp)
if w == nil {
continue
}

resp.Compress = true
break
}
Expand All @@ -418,22 +435,71 @@ func (r *resolver) ServeDNS(w dns.ResponseWriter, query *dns.Msg) {
}
}

func (r *resolver) concurrentQueryInc() bool {
func (r *resolver) forwardQueryStart(w dns.ResponseWriter, msg *dns.Msg) bool {
proto := w.LocalAddr().Network()
dnsID := uint16(rand.Intn(maxDNSID))

cc := clientConn{
dnsID: msg.Id,
respWriter: w,
}

r.queryLock.Lock()
defer r.queryLock.Unlock()

if r.count == maxConcurrent {
return false
}
r.count++

switch proto {
case "tcp":
break
case "udp":
for ok := true; ok == true; dnsID = uint16(rand.Intn(maxDNSID)) {
_, ok = r.client[dnsID]
}
log.Debugf("client dns id %v, changed id %v", msg.Id, dnsID)
r.client[dnsID] = cc
msg.Id = dnsID
default:
log.Errorf("Invalid protocol..")
return false
}

return true
}

func (r *resolver) concurrentQueryDec() bool {
func (r *resolver) forwardQueryEnd(w dns.ResponseWriter, msg *dns.Msg) dns.ResponseWriter {
var (
cc clientConn
ok bool
)
proto := w.LocalAddr().Network()

r.queryLock.Lock()
defer r.queryLock.Unlock()

if r.count == 0 {
return false
log.Errorf("Invalid concurrent query count")
} else {
r.count--
}
r.count--
return true

switch proto {
case "tcp":
break
case "udp":
if cc, ok = r.client[msg.Id]; ok == false {
log.Debugf("Can't retrieve client context for dns id %v", msg.Id)
return nil
}
delete(r.client, msg.Id)
msg.Id = cc.dnsID
w = cc.respWriter
default:
log.Errorf("Invalid protocol")
return nil
}
return w
}

0 comments on commit f113c9a

Please sign in to comment.