dns_tls_test.cpp 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #define LOG_TAG "dns_tls_test"
  17. #define LOG_NDEBUG 1 // Set to 0 to enable verbose debug logging
  18. #include <gtest/gtest.h>
  19. #include "DnsTlsDispatcher.h"
  20. #include "DnsTlsQueryMap.h"
  21. #include "DnsTlsServer.h"
  22. #include "DnsTlsSessionCache.h"
  23. #include "DnsTlsSocket.h"
  24. #include "DnsTlsTransport.h"
  25. #include "IDnsTlsSocket.h"
  26. #include "IDnsTlsSocketFactory.h"
  27. #include "IDnsTlsSocketObserver.h"
  28. #include "dns_responder/dns_tls_frontend.h"
  29. #include <chrono>
  30. #include <arpa/inet.h>
  31. #include <android-base/macros.h>
  32. #include <netdutils/Slice.h>
  33. #include "log/log.h"
  34. namespace android {
  35. namespace net {
  36. using netdutils::Slice;
  37. using netdutils::makeSlice;
  // Byte buffer type used throughout these tests for queries, responses and fingerprints.
  using bytevec = std::vector<uint8_t>;
  39. static void parseServer(const char* server, in_port_t port, sockaddr_storage* parsed) {
  40. sockaddr_in* sin = reinterpret_cast<sockaddr_in*>(parsed);
  41. if (inet_pton(AF_INET, server, &(sin->sin_addr)) == 1) {
  42. // IPv4 parse succeeded, so it's IPv4
  43. sin->sin_family = AF_INET;
  44. sin->sin_port = htons(port);
  45. return;
  46. }
  47. sockaddr_in6* sin6 = reinterpret_cast<sockaddr_in6*>(parsed);
  48. if (inet_pton(AF_INET6, server, &(sin6->sin6_addr)) == 1){
  49. // IPv6 parse succeeded, so it's IPv6.
  50. sin6->sin6_family = AF_INET6;
  51. sin6->sin6_port = htons(port);
  52. return;
  53. }
  54. ALOGE("Failed to parse server address: %s", server);
  55. }
  56. bytevec FINGERPRINT1 = { 1 };
  57. bytevec FINGERPRINT2 = { 2 };
  58. std::string SERVERNAME1 = "dns.example.com";
  59. std::string SERVERNAME2 = "dns.example.org";
  60. // BaseTest just provides constants that are useful for the tests.
  61. class BaseTest : public ::testing::Test {
  62. protected:
  63. BaseTest() {
  64. parseServer("192.0.2.1", 853, &V4ADDR1);
  65. parseServer("192.0.2.2", 853, &V4ADDR2);
  66. parseServer("2001:db8::1", 853, &V6ADDR1);
  67. parseServer("2001:db8::2", 853, &V6ADDR2);
  68. SERVER1 = DnsTlsServer(V4ADDR1);
  69. SERVER1.fingerprints.insert(FINGERPRINT1);
  70. SERVER1.name = SERVERNAME1;
  71. }
  72. sockaddr_storage V4ADDR1;
  73. sockaddr_storage V4ADDR2;
  74. sockaddr_storage V6ADDR1;
  75. sockaddr_storage V6ADDR2;
  76. DnsTlsServer SERVER1;
  77. };
  78. bytevec make_query(uint16_t id, size_t size) {
  79. bytevec vec(size);
  80. vec[0] = id >> 8;
  81. vec[1] = id;
  82. // Arbitrarily fill the query body with unique data.
  83. for (size_t i = 2; i < size; ++i) {
  84. vec[i] = id + i;
  85. }
  86. return vec;
  87. }
  88. // Query constants
  89. const unsigned MARK = 123;
  90. const uint16_t ID = 52;
  91. const uint16_t SIZE = 22;
  92. const bytevec QUERY = make_query(ID, SIZE);
  93. template <class T>
  94. class FakeSocketFactory : public IDnsTlsSocketFactory {
  95. public:
  96. FakeSocketFactory() {}
  97. std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
  98. const DnsTlsServer& server ATTRIBUTE_UNUSED,
  99. unsigned mark ATTRIBUTE_UNUSED,
  100. IDnsTlsSocketObserver* observer,
  101. DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
  102. return std::make_unique<T>(observer);
  103. }
  104. };
  105. bytevec make_echo(uint16_t id, const Slice query) {
  106. bytevec response(query.size() + 2);
  107. response[0] = id >> 8;
  108. response[1] = id;
  109. // Echo the query as the fake response.
  110. memcpy(response.data() + 2, query.base(), query.size());
  111. return response;
  112. }
  113. // Simplest possible fake server. This just echoes the query as the response.
  114. class FakeSocketEcho : public IDnsTlsSocket {
  115. public:
  116. explicit FakeSocketEcho(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
  117. bool query(uint16_t id, const Slice query) override {
  118. // Return the response immediately (asynchronously).
  119. std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query)).detach();
  120. return true;
  121. }
  122. private:
  123. IDnsTlsSocketObserver* const mObserver;
  124. };
  125. class TransportTest : public BaseTest {};
  126. TEST_F(TransportTest, Query) {
  127. FakeSocketFactory<FakeSocketEcho> factory;
  128. DnsTlsTransport transport(SERVER1, MARK, &factory);
  129. auto r = transport.query(makeSlice(QUERY)).get();
  130. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  131. EXPECT_EQ(QUERY, r.response);
  132. }
  133. // Fake Socket that echoes the observed query ID as the response body.
  134. class FakeSocketId : public IDnsTlsSocket {
  135. public:
  136. explicit FakeSocketId(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
  137. bool query(uint16_t id, const Slice query ATTRIBUTE_UNUSED) override {
  138. // Return the response immediately (asynchronously).
  139. bytevec response(4);
  140. // Echo the ID in the header to match the response to the query.
  141. // This will be overwritten by DnsTlsQueryMap.
  142. response[0] = id >> 8;
  143. response[1] = id;
  144. // Echo the ID in the body, so that the test can verify which ID was used by
  145. // DnsTlsQueryMap.
  146. response[2] = id >> 8;
  147. response[3] = id;
  148. std::thread(&IDnsTlsSocketObserver::onResponse, mObserver, response).detach();
  149. return true;
  150. }
  151. private:
  152. IDnsTlsSocketObserver* const mObserver;
  153. };
  154. // Test that IDs are properly reused
  155. TEST_F(TransportTest, IdReuse) {
  156. FakeSocketFactory<FakeSocketId> factory;
  157. DnsTlsTransport transport(SERVER1, MARK, &factory);
  158. for (int i = 0; i < 100; ++i) {
  159. // Send a query.
  160. std::future<DnsTlsServer::Result> f = transport.query(makeSlice(QUERY));
  161. // Wait for the response.
  162. DnsTlsServer::Result r = f.get();
  163. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  164. // All queries should have an observed ID of zero, because it is returned to the ID pool
  165. // after each use.
  166. EXPECT_EQ(0, (r.response[2] << 8) | r.response[3]);
  167. }
  168. }
  169. // These queries might be handled in serial or parallel as they race the
  170. // responses.
  171. TEST_F(TransportTest, RacingQueries_10000) {
  172. FakeSocketFactory<FakeSocketEcho> factory;
  173. DnsTlsTransport transport(SERVER1, MARK, &factory);
  174. std::vector<std::future<DnsTlsTransport::Result>> results;
  175. // Fewer than 65536 queries to avoid ID exhaustion.
  176. const int num_queries = 10000;
  177. results.reserve(num_queries);
  178. for (int i = 0; i < num_queries; ++i) {
  179. results.push_back(transport.query(makeSlice(QUERY)));
  180. }
  181. for (auto& result : results) {
  182. auto r = result.get();
  183. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  184. EXPECT_EQ(QUERY, r.response);
  185. }
  186. }
  187. // A server that waits until sDelay queries are queued before responding.
  188. class FakeSocketDelay : public IDnsTlsSocket {
  189. public:
  190. explicit FakeSocketDelay(IDnsTlsSocketObserver* observer) : mObserver(observer) {}
  191. ~FakeSocketDelay() { std::lock_guard guard(mLock); }
  192. static size_t sDelay;
  193. static bool sReverse;
  194. bool query(uint16_t id, const Slice query) override {
  195. ALOGV("FakeSocketDelay got query with ID %d", int(id));
  196. std::lock_guard guard(mLock);
  197. // Check for duplicate IDs.
  198. EXPECT_EQ(0U, mIds.count(id));
  199. mIds.insert(id);
  200. // Store response.
  201. mResponses.push_back(make_echo(id, query));
  202. ALOGV("Up to %zu out of %zu queries", mResponses.size(), sDelay);
  203. if (mResponses.size() == sDelay) {
  204. std::thread(&FakeSocketDelay::sendResponses, this).detach();
  205. }
  206. return true;
  207. }
  208. private:
  209. void sendResponses() {
  210. std::lock_guard guard(mLock);
  211. if (sReverse) {
  212. std::reverse(std::begin(mResponses), std::end(mResponses));
  213. }
  214. for (auto& response : mResponses) {
  215. mObserver->onResponse(response);
  216. }
  217. mIds.clear();
  218. mResponses.clear();
  219. }
  220. std::mutex mLock;
  221. IDnsTlsSocketObserver* const mObserver;
  222. std::set<uint16_t> mIds GUARDED_BY(mLock);
  223. std::vector<bytevec> mResponses GUARDED_BY(mLock);
  224. };
  225. size_t FakeSocketDelay::sDelay;
  226. bool FakeSocketDelay::sReverse;
  227. TEST_F(TransportTest, ParallelColliding) {
  228. FakeSocketDelay::sDelay = 10;
  229. FakeSocketDelay::sReverse = false;
  230. FakeSocketFactory<FakeSocketDelay> factory;
  231. DnsTlsTransport transport(SERVER1, MARK, &factory);
  232. std::vector<std::future<DnsTlsTransport::Result>> results;
  233. // Fewer than 65536 queries to avoid ID exhaustion.
  234. results.reserve(FakeSocketDelay::sDelay);
  235. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  236. results.push_back(transport.query(makeSlice(QUERY)));
  237. }
  238. for (auto& result : results) {
  239. auto r = result.get();
  240. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  241. EXPECT_EQ(QUERY, r.response);
  242. }
  243. }
  244. TEST_F(TransportTest, ParallelColliding_Max) {
  245. FakeSocketDelay::sDelay = 65536;
  246. FakeSocketDelay::sReverse = false;
  247. FakeSocketFactory<FakeSocketDelay> factory;
  248. DnsTlsTransport transport(SERVER1, MARK, &factory);
  249. std::vector<std::future<DnsTlsTransport::Result>> results;
  250. // Exactly 65536 queries should still be possible in parallel,
  251. // even if they all have the same original ID.
  252. results.reserve(FakeSocketDelay::sDelay);
  253. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  254. results.push_back(transport.query(makeSlice(QUERY)));
  255. }
  256. for (auto& result : results) {
  257. auto r = result.get();
  258. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  259. EXPECT_EQ(QUERY, r.response);
  260. }
  261. }
  262. TEST_F(TransportTest, ParallelUnique) {
  263. FakeSocketDelay::sDelay = 10;
  264. FakeSocketDelay::sReverse = false;
  265. FakeSocketFactory<FakeSocketDelay> factory;
  266. DnsTlsTransport transport(SERVER1, MARK, &factory);
  267. std::vector<bytevec> queries(FakeSocketDelay::sDelay);
  268. std::vector<std::future<DnsTlsTransport::Result>> results;
  269. results.reserve(FakeSocketDelay::sDelay);
  270. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  271. queries[i] = make_query(i, SIZE);
  272. results.push_back(transport.query(makeSlice(queries[i])));
  273. }
  274. for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
  275. auto r = results[i].get();
  276. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  277. EXPECT_EQ(queries[i], r.response);
  278. }
  279. }
  280. TEST_F(TransportTest, ParallelUnique_Max) {
  281. FakeSocketDelay::sDelay = 65536;
  282. FakeSocketDelay::sReverse = false;
  283. FakeSocketFactory<FakeSocketDelay> factory;
  284. DnsTlsTransport transport(SERVER1, MARK, &factory);
  285. std::vector<bytevec> queries(FakeSocketDelay::sDelay);
  286. std::vector<std::future<DnsTlsTransport::Result>> results;
  287. // Exactly 65536 queries should still be possible in parallel,
  288. // and they should all be mapped correctly back to the original ID.
  289. results.reserve(FakeSocketDelay::sDelay);
  290. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  291. queries[i] = make_query(i, SIZE);
  292. results.push_back(transport.query(makeSlice(queries[i])));
  293. }
  294. for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
  295. auto r = results[i].get();
  296. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  297. EXPECT_EQ(queries[i], r.response);
  298. }
  299. }
  300. TEST_F(TransportTest, IdExhaustion) {
  301. const int num_queries = 65536;
  302. // A delay of 65537 is unreachable, because the maximum number
  303. // of outstanding queries is 65536.
  304. FakeSocketDelay::sDelay = num_queries + 1;
  305. FakeSocketDelay::sReverse = false;
  306. FakeSocketFactory<FakeSocketDelay> factory;
  307. DnsTlsTransport transport(SERVER1, MARK, &factory);
  308. std::vector<std::future<DnsTlsTransport::Result>> results;
  309. // Issue the maximum number of queries.
  310. results.reserve(num_queries);
  311. for (int i = 0; i < num_queries; ++i) {
  312. results.push_back(transport.query(makeSlice(QUERY)));
  313. }
  314. // The ID space is now full, so subsequent queries should fail immediately.
  315. auto r = transport.query(makeSlice(QUERY)).get();
  316. EXPECT_EQ(DnsTlsTransport::Response::internal_error, r.code);
  317. EXPECT_TRUE(r.response.empty());
  318. for (auto& result : results) {
  319. // All other queries should remain outstanding.
  320. EXPECT_EQ(std::future_status::timeout,
  321. result.wait_for(std::chrono::duration<int>::zero()));
  322. }
  323. }
  324. // Responses can come back from the server in any order. This should have no
  325. // effect on Transport's observed behavior.
  326. TEST_F(TransportTest, ReverseOrder) {
  327. FakeSocketDelay::sDelay = 10;
  328. FakeSocketDelay::sReverse = true;
  329. FakeSocketFactory<FakeSocketDelay> factory;
  330. DnsTlsTransport transport(SERVER1, MARK, &factory);
  331. std::vector<bytevec> queries(FakeSocketDelay::sDelay);
  332. std::vector<std::future<DnsTlsTransport::Result>> results;
  333. results.reserve(FakeSocketDelay::sDelay);
  334. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  335. queries[i] = make_query(i, SIZE);
  336. results.push_back(transport.query(makeSlice(queries[i])));
  337. }
  338. for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
  339. auto r = results[i].get();
  340. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  341. EXPECT_EQ(queries[i], r.response);
  342. }
  343. }
  344. TEST_F(TransportTest, ReverseOrder_Max) {
  345. FakeSocketDelay::sDelay = 65536;
  346. FakeSocketDelay::sReverse = true;
  347. FakeSocketFactory<FakeSocketDelay> factory;
  348. DnsTlsTransport transport(SERVER1, MARK, &factory);
  349. std::vector<bytevec> queries(FakeSocketDelay::sDelay);
  350. std::vector<std::future<DnsTlsTransport::Result>> results;
  351. results.reserve(FakeSocketDelay::sDelay);
  352. for (size_t i = 0; i < FakeSocketDelay::sDelay; ++i) {
  353. queries[i] = make_query(i, SIZE);
  354. results.push_back(transport.query(makeSlice(queries[i])));
  355. }
  356. for (size_t i = 0 ; i < FakeSocketDelay::sDelay; ++i) {
  357. auto r = results[i].get();
  358. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  359. EXPECT_EQ(queries[i], r.response);
  360. }
  361. }
  362. // Returning null from the factory indicates a connection failure.
  363. class NullSocketFactory : public IDnsTlsSocketFactory {
  364. public:
  365. NullSocketFactory() {}
  366. std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
  367. const DnsTlsServer& server ATTRIBUTE_UNUSED,
  368. unsigned mark ATTRIBUTE_UNUSED,
  369. IDnsTlsSocketObserver* observer ATTRIBUTE_UNUSED,
  370. DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
  371. return nullptr;
  372. }
  373. };
  374. TEST_F(TransportTest, ConnectFail) {
  375. NullSocketFactory factory;
  376. DnsTlsTransport transport(SERVER1, MARK, &factory);
  377. auto r = transport.query(makeSlice(QUERY)).get();
  378. EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
  379. EXPECT_TRUE(r.response.empty());
  380. }
  381. // Simulate a socket that connects but then immediately receives a server
  382. // close notification.
  383. class FakeSocketClose : public IDnsTlsSocket {
  384. public:
  385. explicit FakeSocketClose(IDnsTlsSocketObserver* observer)
  386. : mCloser(&IDnsTlsSocketObserver::onClosed, observer) {}
  387. ~FakeSocketClose() { mCloser.join(); }
  388. bool query(uint16_t id ATTRIBUTE_UNUSED,
  389. const Slice query ATTRIBUTE_UNUSED) override {
  390. return true;
  391. }
  392. private:
  393. std::thread mCloser;
  394. };
  395. TEST_F(TransportTest, CloseRetryFail) {
  396. FakeSocketFactory<FakeSocketClose> factory;
  397. DnsTlsTransport transport(SERVER1, MARK, &factory);
  398. auto r = transport.query(makeSlice(QUERY)).get();
  399. EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
  400. EXPECT_TRUE(r.response.empty());
  401. }
  402. // Simulate a server that occasionally closes the connection and silently
  403. // drops some queries.
  404. class FakeSocketLimited : public IDnsTlsSocket {
  405. public:
  406. static int sLimit; // Number of queries to answer per socket.
  407. static size_t sMaxSize; // Silently discard queries greater than this size.
  408. explicit FakeSocketLimited(IDnsTlsSocketObserver* observer)
  409. : mObserver(observer), mQueries(0) {}
  410. ~FakeSocketLimited() {
  411. {
  412. ALOGV("~FakeSocketLimited acquiring mLock");
  413. std::lock_guard guard(mLock);
  414. ALOGV("~FakeSocketLimited acquired mLock");
  415. for (auto& thread : mThreads) {
  416. ALOGV("~FakeSocketLimited joining response thread");
  417. thread.join();
  418. ALOGV("~FakeSocketLimited joined response thread");
  419. }
  420. mThreads.clear();
  421. }
  422. if (mCloser) {
  423. ALOGV("~FakeSocketLimited joining closer thread");
  424. mCloser->join();
  425. ALOGV("~FakeSocketLimited joined closer thread");
  426. }
  427. }
  428. bool query(uint16_t id, const Slice query) override {
  429. ALOGV("FakeSocketLimited::query acquiring mLock");
  430. std::lock_guard guard(mLock);
  431. ALOGV("FakeSocketLimited::query acquired mLock");
  432. ++mQueries;
  433. if (mQueries <= sLimit) {
  434. ALOGV("size %zu vs. limit of %zu", query.size(), sMaxSize);
  435. if (query.size() <= sMaxSize) {
  436. // Return the response immediately (asynchronously).
  437. mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_echo(id, query));
  438. }
  439. }
  440. if (mQueries == sLimit) {
  441. mCloser = std::make_unique<std::thread>(&FakeSocketLimited::sendClose, this);
  442. }
  443. return mQueries <= sLimit;
  444. }
  445. private:
  446. void sendClose() {
  447. {
  448. ALOGV("FakeSocketLimited::sendClose acquiring mLock");
  449. std::lock_guard guard(mLock);
  450. ALOGV("FakeSocketLimited::sendClose acquired mLock");
  451. for (auto& thread : mThreads) {
  452. ALOGV("FakeSocketLimited::sendClose joining response thread");
  453. thread.join();
  454. ALOGV("FakeSocketLimited::sendClose joined response thread");
  455. }
  456. mThreads.clear();
  457. }
  458. mObserver->onClosed();
  459. }
  460. std::mutex mLock;
  461. IDnsTlsSocketObserver* const mObserver;
  462. int mQueries GUARDED_BY(mLock);
  463. std::vector<std::thread> mThreads GUARDED_BY(mLock);
  464. std::unique_ptr<std::thread> mCloser GUARDED_BY(mLock);
  465. };
  466. int FakeSocketLimited::sLimit;
  467. size_t FakeSocketLimited::sMaxSize;
  468. TEST_F(TransportTest, SilentDrop) {
  469. FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
  470. FakeSocketLimited::sMaxSize = 0; // Silently drop all queries
  471. FakeSocketFactory<FakeSocketLimited> factory;
  472. DnsTlsTransport transport(SERVER1, MARK, &factory);
  473. // Queue up 10 queries. They will all be ignored, and after the 10th,
  474. // the socket will close. Transport will retry them all, until they
  475. // all hit the retry limit and expire.
  476. std::vector<std::future<DnsTlsTransport::Result>> results;
  477. results.reserve(FakeSocketLimited::sLimit);
  478. for (int i = 0; i < FakeSocketLimited::sLimit; ++i) {
  479. results.push_back(transport.query(makeSlice(QUERY)));
  480. }
  481. for (auto& result : results) {
  482. auto r = result.get();
  483. EXPECT_EQ(DnsTlsTransport::Response::network_error, r.code);
  484. EXPECT_TRUE(r.response.empty());
  485. }
  486. }
  487. TEST_F(TransportTest, PartialDrop) {
  488. FakeSocketLimited::sLimit = 10; // Close the socket after 10 queries.
  489. FakeSocketLimited::sMaxSize = SIZE - 2; // Silently drop "long" queries
  490. FakeSocketFactory<FakeSocketLimited> factory;
  491. DnsTlsTransport transport(SERVER1, MARK, &factory);
  492. // Queue up 100 queries, alternating "short" which will be served and "long"
  493. // which will be dropped.
  494. const int num_queries = 10 * FakeSocketLimited::sLimit;
  495. std::vector<bytevec> queries(num_queries);
  496. std::vector<std::future<DnsTlsTransport::Result>> results;
  497. results.reserve(num_queries);
  498. for (int i = 0; i < num_queries; ++i) {
  499. queries[i] = make_query(i, SIZE + (i % 2));
  500. results.push_back(transport.query(makeSlice(queries[i])));
  501. }
  502. // Just check the short queries, which are at the even indices.
  503. for (int i = 0; i < num_queries; i += 2) {
  504. auto r = results[i].get();
  505. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  506. EXPECT_EQ(queries[i], r.response);
  507. }
  508. }
  509. // Simulate a malfunctioning server that injects extra miscellaneous
  510. // responses to queries that were not asked. This will cause wrong answers but
  511. // must not crash the Transport.
  512. class FakeSocketGarbage : public IDnsTlsSocket {
  513. public:
  514. explicit FakeSocketGarbage(IDnsTlsSocketObserver* observer) : mObserver(observer) {
  515. // Inject a garbage event.
  516. mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(ID + 1, SIZE));
  517. }
  518. ~FakeSocketGarbage() {
  519. std::lock_guard guard(mLock);
  520. for (auto& thread : mThreads) {
  521. thread.join();
  522. }
  523. }
  524. bool query(uint16_t id, const Slice query) override {
  525. std::lock_guard guard(mLock);
  526. // Return the response twice.
  527. auto echo = make_echo(id, query);
  528. mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
  529. mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, echo);
  530. // Also return some other garbage
  531. mThreads.emplace_back(&IDnsTlsSocketObserver::onResponse, mObserver, make_query(id + 1, query.size() + 2));
  532. return true;
  533. }
  534. private:
  535. std::mutex mLock;
  536. std::vector<std::thread> mThreads GUARDED_BY(mLock);
  537. IDnsTlsSocketObserver* const mObserver;
  538. };
  539. TEST_F(TransportTest, IgnoringGarbage) {
  540. FakeSocketFactory<FakeSocketGarbage> factory;
  541. DnsTlsTransport transport(SERVER1, MARK, &factory);
  542. for (int i = 0; i < 10; ++i) {
  543. auto r = transport.query(makeSlice(QUERY)).get();
  544. EXPECT_EQ(DnsTlsTransport::Response::success, r.code);
  545. // Don't check the response because this server is malfunctioning.
  546. }
  547. }
  548. // Dispatcher tests
  549. class DispatcherTest : public BaseTest {};
  550. TEST_F(DispatcherTest, Query) {
  551. bytevec ans(4096);
  552. int resplen = 0;
  553. auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
  554. DnsTlsDispatcher dispatcher(std::move(factory));
  555. auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
  556. makeSlice(ans), &resplen);
  557. EXPECT_EQ(DnsTlsTransport::Response::success, r);
  558. EXPECT_EQ(int(QUERY.size()), resplen);
  559. ans.resize(resplen);
  560. EXPECT_EQ(QUERY, ans);
  561. }
  562. TEST_F(DispatcherTest, AnswerTooLarge) {
  563. bytevec ans(SIZE - 1); // Too small to hold the answer
  564. int resplen = 0;
  565. auto factory = std::make_unique<FakeSocketFactory<FakeSocketEcho>>();
  566. DnsTlsDispatcher dispatcher(std::move(factory));
  567. auto r = dispatcher.query(SERVER1, MARK, makeSlice(QUERY),
  568. makeSlice(ans), &resplen);
  569. EXPECT_EQ(DnsTlsTransport::Response::limit_error, r);
  570. }
  571. template<class T>
  572. class TrackingFakeSocketFactory : public IDnsTlsSocketFactory {
  573. public:
  574. TrackingFakeSocketFactory() {}
  575. std::unique_ptr<IDnsTlsSocket> createDnsTlsSocket(
  576. const DnsTlsServer& server,
  577. unsigned mark,
  578. IDnsTlsSocketObserver* observer,
  579. DnsTlsSessionCache* cache ATTRIBUTE_UNUSED) override {
  580. std::lock_guard guard(mLock);
  581. keys.emplace(mark, server);
  582. return std::make_unique<T>(observer);
  583. }
  584. std::multiset<std::pair<unsigned, DnsTlsServer>> keys;
  585. private:
  586. std::mutex mLock;
  587. };
  588. TEST_F(DispatcherTest, Dispatching) {
  589. FakeSocketDelay::sDelay = 5;
  590. FakeSocketDelay::sReverse = true;
  591. auto factory = std::make_unique<TrackingFakeSocketFactory<FakeSocketDelay>>();
  592. auto* weak_factory = factory.get(); // Valid as long as dispatcher is in scope.
  593. DnsTlsDispatcher dispatcher(std::move(factory));
  594. // Populate a vector of two servers and two socket marks, four combinations
  595. // in total.
  596. std::vector<std::pair<unsigned, DnsTlsServer>> keys;
  597. keys.emplace_back(MARK, SERVER1);
  598. keys.emplace_back(MARK + 1, SERVER1);
  599. keys.emplace_back(MARK, V4ADDR2);
  600. keys.emplace_back(MARK + 1, V4ADDR2);
  601. // Do several queries on each server. They should all succeed.
  602. std::vector<std::thread> threads;
  603. for (size_t i = 0; i < FakeSocketDelay::sDelay * keys.size(); ++i) {
  604. auto key = keys[i % keys.size()];
  605. threads.emplace_back([key, i] (DnsTlsDispatcher* dispatcher) {
  606. auto q = make_query(i, SIZE);
  607. bytevec ans(4096);
  608. int resplen = 0;
  609. unsigned mark = key.first;
  610. const DnsTlsServer& server = key.second;
  611. auto r = dispatcher->query(server, mark, makeSlice(q),
  612. makeSlice(ans), &resplen);
  613. EXPECT_EQ(DnsTlsTransport::Response::success, r);
  614. EXPECT_EQ(int(q.size()), resplen);
  615. ans.resize(resplen);
  616. EXPECT_EQ(q, ans);
  617. }, &dispatcher);
  618. }
  619. for (auto& thread : threads) {
  620. thread.join();
  621. }
  622. // We expect that the factory created one socket for each key.
  623. EXPECT_EQ(keys.size(), weak_factory->keys.size());
  624. for (auto& key : keys) {
  625. EXPECT_EQ(1U, weak_factory->keys.count(key));
  626. }
  627. }
  628. // Check DnsTlsServer's comparison logic.
  629. AddressComparator ADDRESS_COMPARATOR;
  630. bool isAddressEqual(const DnsTlsServer& s1, const DnsTlsServer& s2) {
  631. bool cmp1 = ADDRESS_COMPARATOR(s1, s2);
  632. bool cmp2 = ADDRESS_COMPARATOR(s2, s1);
  633. EXPECT_FALSE(cmp1 && cmp2);
  634. return !cmp1 && !cmp2;
  635. }
  636. void checkUnequal(const DnsTlsServer& s1, const DnsTlsServer& s2) {
  637. EXPECT_TRUE(s1 == s1);
  638. EXPECT_TRUE(s2 == s2);
  639. EXPECT_TRUE(isAddressEqual(s1, s1));
  640. EXPECT_TRUE(isAddressEqual(s2, s2));
  641. EXPECT_TRUE(s1 < s2 ^ s2 < s1);
  642. EXPECT_FALSE(s1 == s2);
  643. EXPECT_FALSE(s2 == s1);
  644. }
  645. class ServerTest : public BaseTest {};
  646. TEST_F(ServerTest, IPv4) {
  647. checkUnequal(V4ADDR1, V4ADDR2);
  648. EXPECT_FALSE(isAddressEqual(V4ADDR1, V4ADDR2));
  649. }
  650. TEST_F(ServerTest, IPv6) {
  651. checkUnequal(V6ADDR1, V6ADDR2);
  652. EXPECT_FALSE(isAddressEqual(V6ADDR1, V6ADDR2));
  653. }
  654. TEST_F(ServerTest, MixedAddressFamily) {
  655. checkUnequal(V6ADDR1, V4ADDR1);
  656. EXPECT_FALSE(isAddressEqual(V6ADDR1, V4ADDR1));
  657. }
  658. TEST_F(ServerTest, IPv6ScopeId) {
  659. DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
  660. sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
  661. addr1->sin6_scope_id = 1;
  662. sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
  663. addr2->sin6_scope_id = 2;
  664. checkUnequal(s1, s2);
  665. EXPECT_FALSE(isAddressEqual(s1, s2));
  666. EXPECT_FALSE(s1.wasExplicitlyConfigured());
  667. EXPECT_FALSE(s2.wasExplicitlyConfigured());
  668. }
  669. TEST_F(ServerTest, IPv6FlowInfo) {
  670. DnsTlsServer s1(V6ADDR1), s2(V6ADDR1);
  671. sockaddr_in6* addr1 = reinterpret_cast<sockaddr_in6*>(&s1.ss);
  672. addr1->sin6_flowinfo = 1;
  673. sockaddr_in6* addr2 = reinterpret_cast<sockaddr_in6*>(&s2.ss);
  674. addr2->sin6_flowinfo = 2;
  675. // All comparisons ignore flowinfo.
  676. EXPECT_EQ(s1, s2);
  677. EXPECT_TRUE(isAddressEqual(s1, s2));
  678. EXPECT_FALSE(s1.wasExplicitlyConfigured());
  679. EXPECT_FALSE(s2.wasExplicitlyConfigured());
  680. }
  681. TEST_F(ServerTest, Port) {
  682. DnsTlsServer s1, s2;
  683. parseServer("192.0.2.1", 853, &s1.ss);
  684. parseServer("192.0.2.1", 854, &s2.ss);
  685. checkUnequal(s1, s2);
  686. EXPECT_TRUE(isAddressEqual(s1, s2));
  687. DnsTlsServer s3, s4;
  688. parseServer("2001:db8::1", 853, &s3.ss);
  689. parseServer("2001:db8::1", 852, &s4.ss);
  690. checkUnequal(s3, s4);
  691. EXPECT_TRUE(isAddressEqual(s3, s4));
  692. EXPECT_FALSE(s1.wasExplicitlyConfigured());
  693. EXPECT_FALSE(s2.wasExplicitlyConfigured());
  694. }
  695. TEST_F(ServerTest, Name) {
  696. DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
  697. s1.name = SERVERNAME1;
  698. checkUnequal(s1, s2);
  699. s2.name = SERVERNAME2;
  700. checkUnequal(s1, s2);
  701. EXPECT_TRUE(isAddressEqual(s1, s2));
  702. EXPECT_TRUE(s1.wasExplicitlyConfigured());
  703. EXPECT_TRUE(s2.wasExplicitlyConfigured());
  704. }
  705. TEST_F(ServerTest, Fingerprint) {
  706. DnsTlsServer s1(V4ADDR1), s2(V4ADDR1);
  707. s1.fingerprints.insert(FINGERPRINT1);
  708. checkUnequal(s1, s2);
  709. EXPECT_TRUE(isAddressEqual(s1, s2));
  710. s2.fingerprints.insert(FINGERPRINT2);
  711. checkUnequal(s1, s2);
  712. EXPECT_TRUE(isAddressEqual(s1, s2));
  713. s2.fingerprints.insert(FINGERPRINT1);
  714. checkUnequal(s1, s2);
  715. EXPECT_TRUE(isAddressEqual(s1, s2));
  716. s1.fingerprints.insert(FINGERPRINT2);
  717. EXPECT_EQ(s1, s2);
  718. EXPECT_TRUE(isAddressEqual(s1, s2));
  719. EXPECT_TRUE(s1.wasExplicitlyConfigured());
  720. EXPECT_TRUE(s2.wasExplicitlyConfigured());
  721. }
  722. TEST(QueryMapTest, Basic) {
  723. DnsTlsQueryMap map;
  724. EXPECT_TRUE(map.empty());
  725. bytevec q0 = make_query(999, SIZE);
  726. bytevec q1 = make_query(888, SIZE);
  727. bytevec q2 = make_query(777, SIZE);
  728. auto f0 = map.recordQuery(makeSlice(q0));
  729. auto f1 = map.recordQuery(makeSlice(q1));
  730. auto f2 = map.recordQuery(makeSlice(q2));
  731. // Check return values of recordQuery
  732. EXPECT_EQ(0, f0->query.newId);
  733. EXPECT_EQ(1, f1->query.newId);
  734. EXPECT_EQ(2, f2->query.newId);
  735. // Check side effects of recordQuery
  736. EXPECT_FALSE(map.empty());
  737. auto all = map.getAll();
  738. EXPECT_EQ(3U, all.size());
  739. EXPECT_EQ(0, all[0].newId);
  740. EXPECT_EQ(1, all[1].newId);
  741. EXPECT_EQ(2, all[2].newId);
  742. EXPECT_EQ(makeSlice(q0), all[0].query);
  743. EXPECT_EQ(makeSlice(q1), all[1].query);
  744. EXPECT_EQ(makeSlice(q2), all[2].query);
  745. bytevec a0 = make_query(0, SIZE);
  746. bytevec a1 = make_query(1, SIZE);
  747. bytevec a2 = make_query(2, SIZE);
  748. // Return responses out of order
  749. map.onResponse(a2);
  750. map.onResponse(a0);
  751. map.onResponse(a1);
  752. EXPECT_TRUE(map.empty());
  753. auto r0 = f0->result.get();
  754. auto r1 = f1->result.get();
  755. auto r2 = f2->result.get();
  756. EXPECT_EQ(DnsTlsQueryMap::Response::success, r0.code);
  757. EXPECT_EQ(DnsTlsQueryMap::Response::success, r1.code);
  758. EXPECT_EQ(DnsTlsQueryMap::Response::success, r2.code);
  759. const bytevec& d0 = r0.response;
  760. const bytevec& d1 = r1.response;
  761. const bytevec& d2 = r2.response;
  762. // The ID should match the query
  763. EXPECT_EQ(999, d0[0] << 8 | d0[1]);
  764. EXPECT_EQ(888, d1[0] << 8 | d1[1]);
  765. EXPECT_EQ(777, d2[0] << 8 | d2[1]);
  766. // The body should match the answer
  767. EXPECT_EQ(bytevec(a0.begin() + 2, a0.end()), bytevec(d0.begin() + 2, d0.end()));
  768. EXPECT_EQ(bytevec(a1.begin() + 2, a1.end()), bytevec(d1.begin() + 2, d1.end()));
  769. EXPECT_EQ(bytevec(a2.begin() + 2, a2.end()), bytevec(d2.begin() + 2, d2.end()));
  770. }
  771. TEST(QueryMapTest, FillHole) {
  772. DnsTlsQueryMap map;
  773. std::vector<std::unique_ptr<DnsTlsQueryMap::QueryFuture>> futures(UINT16_MAX + 1);
  774. for (uint32_t i = 0; i <= UINT16_MAX; ++i) {
  775. futures[i] = map.recordQuery(makeSlice(QUERY));
  776. ASSERT_TRUE(futures[i]); // answers[i] should be nonnull.
  777. EXPECT_EQ(i, futures[i]->query.newId);
  778. }
  779. // The map should now be full.
  780. EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
  781. // Trying to add another query should fail because the map is full.
  782. EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
  783. // Send an answer to query 40000
  784. auto answer = make_query(40000, SIZE);
  785. map.onResponse(answer);
  786. auto result = futures[40000]->result.get();
  787. EXPECT_EQ(DnsTlsQueryMap::Response::success, result.code);
  788. EXPECT_EQ(ID, result.response[0] << 8 | result.response[1]);
  789. EXPECT_EQ(bytevec(answer.begin() + 2, answer.end()),
  790. bytevec(result.response.begin() + 2, result.response.end()));
  791. // There should now be room in the map.
  792. EXPECT_EQ(size_t(UINT16_MAX), map.getAll().size());
  793. auto f = map.recordQuery(makeSlice(QUERY));
  794. ASSERT_TRUE(f);
  795. EXPECT_EQ(40000, f->query.newId);
  796. // The map should now be full again.
  797. EXPECT_EQ(size_t(UINT16_MAX + 1), map.getAll().size());
  798. EXPECT_FALSE(map.recordQuery(makeSlice(QUERY)));
  799. }
  800. class StubObserver : public IDnsTlsSocketObserver {
  801. public:
  802. bool closed = false;
  803. void onResponse(std::vector<uint8_t>) override {}
  804. void onClosed() override { closed = true; }
  805. };
  806. TEST(DnsTlsSocketTest, SlowDestructor) {
  807. constexpr char tls_addr[] = "127.0.0.3";
  808. constexpr char tls_port[] = "8530"; // High-numbered port so root isn't required.
  809. // This test doesn't perform any queries, so the backend address can be invalid.
  810. constexpr char backend_addr[] = "192.0.2.1";
  811. constexpr char backend_port[] = "1";
  812. test::DnsTlsFrontend tls(tls_addr, tls_port, backend_addr, backend_port);
  813. ASSERT_TRUE(tls.startServer());
  814. DnsTlsServer server;
  815. parseServer(tls_addr, 8530, &server.ss);
  816. StubObserver observer;
  817. ASSERT_FALSE(observer.closed);
  818. DnsTlsSessionCache cache;
  819. auto socket = std::make_unique<DnsTlsSocket>(server, MARK, &observer, &cache);
  820. ASSERT_TRUE(socket->initialize());
  821. // Test: Time the socket destructor. This should be fast.
  822. auto before = std::chrono::steady_clock::now();
  823. socket.reset();
  824. auto after = std::chrono::steady_clock::now();
  825. auto delay = after - before;
  826. ALOGV("Shutdown took %lld ns", delay / std::chrono::nanoseconds{1});
  827. EXPECT_TRUE(observer.closed);
  828. // Shutdown should complete in milliseconds, but if the shutdown signal is lost
  829. // it will wait for the timeout, which is expected to take 20seconds.
  830. EXPECT_LT(delay, std::chrono::seconds{5});
  831. }
  832. } // end of namespace net
  833. } // end of namespace android