// Copyright (c) 2010 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/base/dnsrr_resolver.h"

#if defined(OS_POSIX)
#include <resolv.h>
#endif

#include "base/lock.h"
#include "base/message_loop.h"
#include "base/scoped_ptr.h"
#include "base/singleton.h"
#include "base/stl_util-inl.h"
#include "base/string_piece.h"
#include "base/task.h"
#include "base/worker_pool.h"
#include "net/base/dns_reload_timer.h"
#include "net/base/dns_util.h"
#include "net/base/net_errors.h"

// RRResolverWorker and RRResolverHandle manage their own lifetimes: they
// `delete this` (RRResolverWorker in DoReply()/Finish(), RRResolverHandle in
// Post()). These macros opt them out of the reference counting that
// NewRunnableMethod would otherwise apply to its target object.
DISABLE_RUNNABLE_METHOD_REFCOUNT(net::RRResolverWorker);
DISABLE_RUNNABLE_METHOD_REFCOUNT(net::RRResolverHandle);

// Life of a query:
//
// DnsRRResolver RRResolverJob RRResolverWorker       ...         Handle
//      |                       (origin loop)    (worker loop)
//      |
//   Resolve()
//      |---->----<creates>
//      |
//      |---->-------------------<creates>
//      |
//      |---->---------------------------------------------------<creates>
//      |
//      |---->--------------------Start
//      |                           |
//      |                        PostTask
//      |
//      |                                     <starts resolving>
//      |---->-----AddHandle                          |
//                                                    |
//                                                    |
//                                                    |
//                                                  Finish
//                                                    |
//                                                 PostTask
//
//                                   |
//                                DoReply
//      |----<-----------------------|
//  HandleResult
//      |
//      |---->-----HandleResult
//                      |
//                      |------>-----------------------------------Post
//
//
//
// A cache hit:
//
// DnsRRResolver CacheHitCallbackTask  Handle
//      |
//   Resolve()
//      |---->----<creates>
//      |
//      |---->------------------------<creates>
//      |
//      |
//   PostTask
//
// (MessageLoop cycles)
//
//                   Run
//                    |
//                    |----->-----------Post



namespace net {

// DNS class code for IN (the Internet class), RFC 1035 section 3.2.4.
static const uint16 kClassIN = 1;
// kMaxCacheEntries is the number of RRResponse objects that we'll cache.
static const unsigned kMaxCacheEntries = 32;
// kNegativeTTLSecs is the number of seconds for which we'll cache a negative
// cache entry (a failed lookup).
static const unsigned kNegativeTTLSecs = 60;

// A default-constructed RRResponse is an empty, non-DNSSEC, positive answer
// with a zero TTL.
RRResponse::RRResponse() : ttl(0), dnssec(false), negative(false) {}

RRResponse::~RRResponse() {
}

// An RRResolverHandle represents one caller's outstanding request. Resolve()
// hands it to the caller as an opaque intptr_t. The handle deletes itself the
// first (and only) time Post() runs, whether or not it was canceled.
class RRResolverHandle {
 public:
  RRResolverHandle(CompletionCallback* callback, RRResponse* response)
      : callback_(callback),
        response_(response) {
  }

  // Cancel ensures that the result callback will never be made. The handle is
  // still freed when Post() eventually runs.
  void Cancel() {
    callback_ = NULL;
  }

  // Post copies the contents of |response| to the caller's RRResponse and
  // calls the callback, unless the request was canceled. In every case the
  // handle then deletes itself: PostAll() and the cache-hit task both rely on
  // Post() to free it, so skipping the delete on cancellation would leak.
  void Post(int rv, const RRResponse* response) {
    if (callback_) {
      if (response_ && response)
        *response_ = *response;
      callback_->Run(rv);
    }
    delete this;
  }

 private:
  friend class RRResolverWorker;
  friend class DnsRRResolver;

  // NULL once canceled.
  CompletionCallback* callback_;
  // Caller-owned destination for the result; NULL for cache hits, where the
  // response was already written synchronously by Resolve().
  RRResponse* const response_;
};


// RRResolverWorker runs on a worker thread and takes care of the blocking
// process of performing the DNS resolution.
//
// Lifetime: the worker owns itself. It is created on the origin loop by
// DnsRRResolver::Resolve(), performs the query on a WorkerPool thread in Run(),
// and then deletes itself either in Finish() (if it was canceled before
// finishing) or in DoReply() back on the origin loop. |lock_| guards
// |canceled_|, which Cancel() writes on the origin loop and Finish() reads on
// the worker thread.
class RRResolverWorker {
 public:
  RRResolverWorker(const std::string& name, uint16 rrtype, uint16 flags,
                   DnsRRResolver* dnsrr_resolver)
      : name_(name),
        rrtype_(rrtype),
        flags_(flags),
        origin_loop_(MessageLoop::current()),
        dnsrr_resolver_(dnsrr_resolver),
        canceled_(false) {
  }

  // Start schedules Run() on the WorkerPool. Returns the result of
  // WorkerPool::PostTask; on false the caller still owns this object.
  bool Start() {
    DCHECK_EQ(MessageLoop::current(), origin_loop_);

    return WorkerPool::PostTask(
               FROM_HERE, NewRunnableMethod(this, &RRResolverWorker::Run),
               true /* task is slow */);
  }

  // Cancel is called from the origin loop when the owning RRResolverJob is
  // canceled or destroyed (e.g. when the DnsRRResolver is deleted or the IP
  // address changes). After this, no result will be reported.
  void Cancel() {
    DCHECK_EQ(MessageLoop::current(), origin_loop_);
    AutoLock locked(lock_);
    canceled_ = true;
  }

 private:

#if defined(OS_POSIX)

  virtual void Run() {
    // Runs on a worker thread.
    // NOTE(review): this reads and writes |_res| directly, which presumably
    // relies on |_res| being per-thread state (as in glibc) -- confirm for
    // other libcs.

    if (HandleTestCases()) {
      Finish();
      return;
    }

    // Initialize the resolver state on first use.
    bool r = true;
    if ((_res.options & RES_INIT) == 0) {
      if (res_ninit(&_res) != 0)
        r = false;
    }

    if (r) {
      // Do() overwrites |_res.options|; restore them afterwards.
      unsigned long saved_options = _res.options;
      r = Do();

#if defined(OS_POSIX) && !defined(OS_MACOSX) && !defined(OS_OPENBSD)
      // If the query failed and DnsReloadTimerHasExpired() returns true,
      // re-read the resolver configuration and retry once.
      if (!r && DnsReloadTimerHasExpired()) {
        res_nclose(&_res);
        if (res_ninit(&_res) == 0)
          r = Do();
      }
#endif
      _res.options = saved_options;
    }

    response_.fetch_time = base::Time::Now();

    if (r) {
      result_ = OK;
    } else {
      // Failures are cached as negative entries for kNegativeTTLSecs.
      result_ = ERR_NAME_NOT_RESOLVED;
      response_.negative = true;
      response_.ttl = kNegativeTTLSecs;
    }

    Finish();
  }

  // Do performs a single res_search() for |name_|/|rrtype_| and parses the
  // answer into |response_|. Returns false if res_search fails or the answer
  // does not parse.
  bool Do() {
    // For DNSSEC, a 4K buffer is suggested
    static const unsigned kMaxDNSPayload = 4096;

#ifndef RES_USE_DNSSEC
    // Some versions of libresolv don't have support for the DO bit. In this
    // case, we proceed without it.
    static const int RES_USE_DNSSEC = 0;
#endif

#ifndef RES_USE_EDNS0
    // Some versions of glibc are so old that they don't support EDNS0 either.
    // http://code.google.com/p/chromium/issues/detail?id=51676
    static const int RES_USE_EDNS0 = 0;
#endif

    // We set the options explicitly. Note that this removes several default
    // options: RES_DEFNAMES and RES_DNSRCH (see res_init(3)).
    _res.options = RES_INIT | RES_RECURSE | RES_USE_EDNS0 | RES_USE_DNSSEC;
    uint8 answer[kMaxDNSPayload];
    int len = res_search(name_.c_str(), kClassIN, rrtype_, answer,
                         sizeof(answer));
    if (len == -1)
      return false;

    return response_.ParseFromResponse(answer, len, rrtype_);
  }

#else  // OS_WIN

  // On non-POSIX platforms every query except the test cases fails.
  // NOTE(review): unlike the POSIX Run(), this never sets response_.ttl to
  // kNegativeTTLSecs, so the negative cache entry expires immediately --
  // confirm whether that is intended.
  virtual void Run() {
    if (HandleTestCases()) {
      Finish();
      return;
    }

    response_.fetch_time = base::Time::Now();
    response_.negative = true;
    result_ = ERR_NAME_NOT_RESOLVED;
    Finish();
  }

#endif // OS_WIN

  // HandleTestCases stuffs in magic test values in the event that the query is
  // from a unittest. Returns true if |response_| and |result_| were filled in.
  bool HandleTestCases() {
    if (rrtype_ == kDNS_TESTING) {
      response_.fetch_time = base::Time::Now();

      if (name_ == "www.testing.notatld") {
        response_.ttl = 86400;
        response_.negative = false;
        response_.rrdatas.push_back("goats!");
        result_ = OK;
        return true;
      } else if (name_ == "nx.testing.notatld") {
        response_.negative = true;
        result_ = ERR_NAME_NOT_RESOLVED;
        return true;
      }
    }

    return false;
  }

  // DoReply runs on the origin thread. It reports the result to the resolver
  // unless canceled, then deletes this worker.
  void DoReply() {
    DCHECK_EQ(MessageLoop::current(), origin_loop_);
    {
      // We lock here because the worker thread could still be in Finished,
      // after the PostTask, but before unlocking |lock_|. In this case, we end
      // up deleting a locked Lock, which can lead to memory leaks.
      AutoLock locked(lock_);
      if (!canceled_)
        dnsrr_resolver_->HandleResult(name_, rrtype_, result_, response_);
    }
    delete this;
  }

  void Finish() {
    // Runs on the worker thread.
    // We assume that the origin loop outlives the DnsRRResolver. If the
    // DnsRRResolver is deleted, it will call Cancel on us. If it does so
    // before the Acquire, we'll delete ourselves and return. If it's trying to
    // do so concurrently, then it'll block on the lock and we'll call PostTask
    // while the DnsRRResolver (and therefore the MessageLoop) is still alive.
    // If it does so after this function, we assume that the MessageLoop will
    // process pending tasks. In which case we'll notice the |canceled_| flag
    // in DoReply.

    bool canceled;
    {
      AutoLock locked(lock_);
      canceled = canceled_;
      if (!canceled) {
        origin_loop_->PostTask(
            FROM_HERE, NewRunnableMethod(this, &RRResolverWorker::DoReply));
      }
    }

    if (canceled)
      delete this;
  }

  const std::string name_;
  const uint16 rrtype_;
  // Stored but not read anywhere in this file.
  const uint16 flags_;
  MessageLoop* const origin_loop_;
  DnsRRResolver* const dnsrr_resolver_;

  // Guards |canceled_|.
  Lock lock_;
  bool canceled_;

  // Written on the worker thread, read in DoReply() on the origin loop.
  int result_;
  RRResponse response_;

  DISALLOW_COPY_AND_ASSIGN(RRResolverWorker);
};


// A Buffer is used for walking over a DNS packet. The fixed-size readers
// (U8/U16/U32/Skip/Block) bounds-check against the bytes remaining and return
// false, consuming nothing, if there are not enough. Multi-byte integers are
// read in network (big-endian) byte order.
class Buffer {
 public:
  Buffer(const uint8* p, unsigned len)
      : p_(p),
        packet_(p),
        len_(len),
        packet_len_(len) {
  }

  bool U8(uint8* v) {
    if (len_ < 1)
      return false;
    *v = *p_;
    p_++;
    len_--;
    return true;
  }

  bool U16(uint16* v) {
    if (len_ < 2)
      return false;
    *v = static_cast<uint16>(p_[0]) << 8 |
         static_cast<uint16>(p_[1]);
    p_ += 2;
    len_ -= 2;
    return true;
  }

  bool U32(uint32* v) {
    if (len_ < 4)
      return false;
    *v = static_cast<uint32>(p_[0]) << 24 |
         static_cast<uint32>(p_[1]) << 16 |
         static_cast<uint32>(p_[2]) << 8 |
         static_cast<uint32>(p_[3]);
    p_ += 4;
    len_ -= 4;
    return true;
  }

  bool Skip(unsigned n) {
    if (len_ < n)
      return false;
    p_ += n;
    len_ -= n;
    return true;
  }

  // Block points |out| at the next |len| bytes (no copy) and consumes them.
  bool Block(base::StringPiece* out, unsigned len) {
    if (len_ < len)
      return false;
    *out = base::StringPiece(reinterpret_cast<const char*>(p_), len);
    p_ += len;
    len_ -= len;
    return true;
  }

  // DNSName parses a (possibly compressed) DNS name from the packet. If |name|
  // is not NULL, then the name is written into it, with labels joined by '.'.
  // See RFC 1035 section 4.1.4. On failure |name| may hold a partial name and
  // the read position may already have advanced.
  bool DNSName(std::string* name) {
    unsigned jumps = 0;
    // |p|/|len| track the current read position, which may be elsewhere in the
    // packet after following a compression pointer. |p_|/|len_| only advance
    // over the bytes of the name before the first pointer.
    const uint8* p = p_;
    unsigned len = len_;

    if (name)
      name->clear();

    for (;;) {
      if (len < 1)
        return false;
      uint8 d = *p;
      p++;
      len--;

      // The top two bits of the length octet give its type: 00 means a direct
      // label length, 11 means a pointer to the remainder of the name.
      if ((d & 0xc0) == 0xc0) {
        // This limit matches the depth limit in djbdns.
        if (jumps > 100)
          return false;
        if (len < 1)
          return false;
        // A pointer is a 14-bit offset from the start of the packet.
        uint16 offset = static_cast<uint16>(d) << 8 |
                        static_cast<uint16>(p[0]);
        offset &= 0x3fff;
        p++;
        len--;

        if (jumps == 0) {
          p_ = p;
          len_ = len;
        }
        jumps++;

        if (offset >= packet_len_)
          return false;
        p = &packet_[offset];
        // Subsequent reads are bounded by the end of the packet, not by
        // whatever happened to remain after the pointer we just followed.
        len = packet_len_ - offset;
      } else if ((d & 0xc0) == 0) {
        uint8 label_len = d;
        if (len < label_len)
          return false;
        if (name && label_len) {
          if (!name->empty())
            name->append(".");
          name->append(reinterpret_cast<const char*>(p), label_len);
        }
        p += label_len;
        len -= label_len;

        if (jumps == 0) {
          p_ = p;
          len_ = len;
        }

        // A zero-length label terminates the name.
        if (label_len == 0)
          break;
      } else {
        // 01 and 10 prefixes are reserved.
        return false;
      }
    }

    return true;
  }

 private:
  DISALLOW_COPY_AND_ASSIGN(Buffer);

  // Current read position and bytes remaining from it.
  const uint8* p_;
  // Start and total length of the whole packet, for compression pointers.
  const uint8* const packet_;
  unsigned len_;
  const unsigned packet_len_;
};

// An RRResponse is stale once |ttl| seconds have elapsed since |fetch_time|;
// the boundary instant itself already counts as expired.
bool RRResponse::HasExpired(const base::Time current_time) const {
  return current_time >= fetch_time + base::TimeDelta::FromSeconds(ttl);
}

// ParseFromResponse fills this RRResponse from the raw DNS reply in |p|/|len|.
// It requires exactly one question matching |rrtype_requested| in class IN and
// at least one answer record. Answer records of the requested type are stored
// in |rrdatas| and RRSIG records in |signatures|; the owner name and TTL come
// from the first answer. On non-POSIX platforms this is a no-op returning true.
bool RRResponse::ParseFromResponse(const uint8* p, unsigned len,
                                   uint16 rrtype_requested) {
#if defined(OS_POSIX)
  // Start from a clean slate so nothing from a previous parse survives.
  name.clear();
  ttl = 0;
  dnssec = false;
  negative = false;
  rrdatas.clear();
  signatures.clear();

  Buffer reader(p, len);

  // Message header, RFC 1035 section 4.1.1: a 16-bit ID and the first flags
  // octet are skipped; the second flags octet is kept.
  uint8 second_flags_byte;
  if (!reader.Skip(3) || !reader.U8(&second_flags_byte))
    return false;

  // 0x20 in the second flags octet is the Authenticated Data (AD) bit, see
  // http://tools.ietf.org/html/rfc2535#section-6.1. It is only trusted when
  // the sole configured nameserver is 127.0.0.1; the resolver state holds
  // IPv4 addresses only, so IPv6 loopback cannot be recognised here.
  const bool ad_bit_set = (second_flags_byte & 0x20) != 0;
  if (ad_bit_set && _res.nscount == 1 &&
      memcmp(&_res.nsaddr_list[0].sin_addr,
             "\x7f\x00\x00\x01" /* 127.0.0.1 */, 4) == 0) {
    dnssec = true;
  }

  uint16 question_count, answer_records, authority_records, additional_records;
  if (!reader.U16(&question_count) || !reader.U16(&answer_records) ||
      !reader.U16(&authority_records) || !reader.U16(&additional_records)) {
    return false;
  }

  if (question_count != 1)
    return false;

  // The single question must echo the type we asked for, in class IN.
  uint16 question_type, question_class;
  if (!reader.DNSName(NULL) || !reader.U16(&question_type) ||
      !reader.U16(&question_class)) {
    return false;
  }
  if (question_type != rrtype_requested || question_class != kClassIN)
    return false;

  if (answer_records == 0)
    return false;

  for (uint32 index = 0; index < answer_records; ++index) {
    // Only the first answer's owner name is recorded.
    std::string* owner = (index == 0) ? &name : NULL;

    uint16 rr_type, rr_class, rdlength;
    uint32 rr_ttl;
    base::StringPiece rdata;
    if (!reader.DNSName(owner) || !reader.U16(&rr_type) ||
        !reader.U16(&rr_class) || !reader.U32(&rr_ttl) ||
        !reader.U16(&rdlength) || !reader.Block(&rdata, rdlength)) {
      return false;
    }

    if (rr_class != kClassIN)
      continue;

    if (rr_type == rrtype_requested) {
      if (index == 0)
        ttl = rr_ttl;
      rrdatas.push_back(std::string(rdata.data(), rdata.size()));
    } else if (rr_type == kDNS_RRSIG) {
      signatures.push_back(std::string(rdata.data(), rdata.size()));
    }
  }
#endif  // defined(OS_POSIX)

  return true;
}


// An RRResolverJob is a one-to-one counterpart of an RRResolverWorker. It
// lives only on the DnsRRResolver's origin message loop.
class RRResolverJob {
 public:
  RRResolverJob(RRResolverWorker* worker)
      : worker_(worker) {
  }

  ~RRResolverJob() {
    Cancel(ERR_NAME_NOT_RESOLVED);
  }

  void AddHandle(RRResolverHandle* handle) {
    handles_.push_back(handle);
  }

  void HandleResult(int result, const RRResponse& response) {
    worker_ = NULL;
    PostAll(result, &response);
  }

  void Cancel(int error) {
    if (worker_) {
      worker_->Cancel();
      worker_ = NULL;
    }

    PostAll(error, NULL);
  }

 private:
  void PostAll(int result, const RRResponse* response) {
    std::vector<RRResolverHandle*> handles;
    handles_.swap(handles);

    for (std::vector<RRResolverHandle*>::iterator
         i = handles.begin(); i != handles.end(); i++) {
      (*i)->Post(result, response);
      // Post() causes the RRResolverHandle to delete itself.
    }
  }

  std::vector<RRResolverHandle*> handles_;
  RRResolverWorker* worker_;
};


// A fresh resolver has an empty cache, nothing in flight and zeroed counters.
DnsRRResolver::DnsRRResolver()
    : requests_(0), cache_hits_(0), inflight_joins_(0), in_destructor_(false) {}

// Tearing down the resolver deletes every in-flight job. Each job's destructor
// cancels its worker and fails its waiting handles with ERR_NAME_NOT_RESOLVED.
DnsRRResolver::~DnsRRResolver() {
  DCHECK(!in_destructor_);
  in_destructor_ = true;

  typedef std::map<std::pair<std::string, uint16>, RRResolverJob*> JobMap;
  for (JobMap::iterator it = inflight_.begin(); it != inflight_.end(); ++it)
    delete it->second;
  inflight_.clear();
}

// Resolve starts, or joins, a lookup of |name|/|rrtype|. The result is
// delivered asynchronously through |callback|, and |response| may be written
// before that callback runs. Cache hits are answered from |cache_|; otherwise
// the request is attached to an existing in-flight job or a new worker is
// started. Returns an opaque handle for CancelResolve(), or kInvalidHandle if:
//   - |callback| or |response| is NULL or |name| is empty,
//   - |rrtype| is kDNS_ANY, or
//   - the worker could not be started.
intptr_t DnsRRResolver::Resolve(const std::string& name, uint16 rrtype,
                                uint16 flags, CompletionCallback* callback,
                                RRResponse* response,
                                int priority /* ignored */,
                                const BoundNetLog& netlog /* ignored */) {
  DCHECK(CalledOnValidThread());
  DCHECK(!in_destructor_);

  if (!callback || !response || name.empty())
    return kInvalidHandle;

  // Don't allow queries of type ANY
  if (rrtype == kDNS_ANY)
    return kInvalidHandle;

  requests_++;

  const std::pair<std::string, uint16> key(make_pair(name, rrtype));
  // First check the cache.
  std::map<std::pair<std::string, uint16>, RRResponse>::iterator i;
  i = cache_.find(key);
  if (i != cache_.end()) {
    if (!i->second.HasExpired(base::Time::Now())) {
      int error;
      if (i->second.negative) {
        error = ERR_NAME_NOT_RESOLVED;
      } else {
        error = OK;
        *response = i->second;
      }
      RRResolverHandle* handle = new RRResolverHandle(
          callback, NULL /* no response pointer because we've already filled */
                         /* it in */);
      cache_hits_++;
      // We need a typed NULL pointer in order to make the templates work out.
      static const RRResponse* kNoResponse = NULL;
      MessageLoop::current()->PostTask(
          FROM_HERE, NewRunnableMethod(handle, &RRResolverHandle::Post, error,
                                       kNoResponse));
      return reinterpret_cast<intptr_t>(handle);
    } else {
      // entry has expired.
      cache_.erase(i);
    }
  }

  // No cache hit. See if a request is currently in flight.
  RRResolverJob* job;
  std::map<std::pair<std::string, uint16>, RRResolverJob*>::const_iterator j;
  j = inflight_.find(key);
  if (j != inflight_.end()) {
    // The request is in flight already. We'll just attach our callback.
    inflight_joins_++;
    job = j->second;
  } else {
    // Need to make a new request.
    RRResolverWorker* worker = new RRResolverWorker(name, rrtype, flags, this);
    job = new RRResolverJob(worker);
    inflight_.insert(make_pair(key, job));
    if (!worker->Start()) {
      // Unregister the job before deleting it. Otherwise |inflight_| would
      // keep a dangling pointer that a later Resolve() for the same key would
      // join.
      inflight_.erase(key);
      delete job;
      delete worker;
      return kInvalidHandle;
    }
  }

  RRResolverHandle* handle = new RRResolverHandle(callback, response);
  job->AddHandle(handle);
  return reinterpret_cast<intptr_t>(handle);
}

// CancelResolve suppresses the callback for the request identified by |h|,
// which must be a handle previously returned by Resolve().
void DnsRRResolver::CancelResolve(intptr_t h) {
  DCHECK(CalledOnValidThread());
  reinterpret_cast<RRResolverHandle*>(h)->Cancel();
}

void DnsRRResolver::OnIPAddressChanged() {
  DCHECK(CalledOnValidThread());
  DCHECK(!in_destructor_);

  std::map<std::pair<std::string, uint16>, RRResolverJob*> inflight;
  inflight.swap(inflight_);
  cache_.clear();

  for (std::map<std::pair<std::string, uint16>, RRResolverJob*>::iterator
       i = inflight.begin(); i != inflight.end(); i++) {
    i->second->Cancel(ERR_ABORTED);
    delete i->second;
  }
}

// HandleResult is called on the origin message loop.
void DnsRRResolver::HandleResult(const std::string& name, uint16 rrtype,
                                 int result, const RRResponse& response) {
  DCHECK(CalledOnValidThread());

  const std::pair<std::string, uint16> key(std::make_pair(name, rrtype));

  DCHECK_GE(kMaxCacheEntries, 1u);
  DCHECK_LE(cache_.size(), kMaxCacheEntries);
  if (cache_.size() == kMaxCacheEntries) {
    // need to remove an element of the cache.
    const base::Time current_time(base::Time::Now());
    for (std::map<std::pair<std::string, uint16>, RRResponse>::iterator
         i = cache_.begin(); i != cache_.end(); ++i) {
      if (i->second.HasExpired(current_time)) {
        cache_.erase(i);
        break;
      }
    }
  }
  if (cache_.size() == kMaxCacheEntries) {
    // if we didn't clear out any expired entries, we just remove the first
    // element. Crummy but simple.
    cache_.erase(cache_.begin());
  }

  cache_.insert(std::make_pair(key, response));

  std::map<std::pair<std::string, uint16>, RRResolverJob*>::iterator j;
  j = inflight_.find(key);
  if (j == inflight_.end()) {
    NOTREACHED();
    return;
  }
  RRResolverJob* job = j->second;
  inflight_.erase(j);

  job->HandleResult(result, response);
  delete job;
}

}  // namespace net
