From 27c378a88af4e4db6f26b66438040da598f11b82 Mon Sep 17 00:00:00 2001 From: nsubiron Date: Mon, 8 Oct 2018 23:02:51 +0200 Subject: [PATCH] Fix issues related to unsubscribing from a stream --- .../carla/streaming/detail/tcp/Client.cpp | 58 +++++++------ .../carla/streaming/detail/tcp/Client.h | 14 +-- .../streaming/detail/tcp/ServerSession.cpp | 28 ++++-- .../streaming/detail/tcp/ServerSession.h | 2 + .../source/carla/streaming/low_level/Client.h | 26 ++++-- .../source/test/test_streaming_low_level.cpp | 86 ++++++++++++++++--- .../test/test_streaming_low_level_tcp.cpp | 4 +- 7 files changed, 158 insertions(+), 60 deletions(-) diff --git a/LibCarla/source/carla/streaming/detail/tcp/Client.cpp b/LibCarla/source/carla/streaming/detail/tcp/Client.cpp index 88ebbb91b..b0113f22b 100644 --- a/LibCarla/source/carla/streaming/detail/tcp/Client.cpp +++ b/LibCarla/source/carla/streaming/detail/tcp/Client.cpp @@ -75,27 +75,13 @@ namespace tcp { if (!_token.protocol_is_tcp()) { throw std::invalid_argument("invalid token, only TCP tokens supported"); } - Connect(); } - Client::~Client() { - _done = true; - /// @todo Destroying this client is not safe, another thread might be still - /// @using it. - } - - void Client::Stop() { - _done = true; - _connection_timer.cancel(); - _strand.post([this]() { - if (_socket.is_open()) { - _socket.close(); - } - }); - } + Client::~Client() = default; void Client::Connect() { - _strand.post([this]() { + auto self = shared_from_this(); + _strand.post([this, self]() { if (_done) { return; } @@ -110,8 +96,11 @@ namespace tcp { DEBUG_ASSERT(_token.protocol_is_tcp()); const auto ep = _token.to_tcp_endpoint(); - auto handle_connect = [=](error_code ec) { + auto handle_connect = [this, self, ep](error_code ec) { if (!ec) { + if (_done) { + return; + } log_debug("streaming client: connected to", ep); // Send the stream id to subscribe to the stream. const auto &stream_id = _token.get_stream_id(); @@ -126,12 +115,12 @@ namespace tcp { ReadData(); } else { // Else try again. - log_warning("streaming client: failed to send stream id:", ec.message()); + log_info("streaming client: failed to send stream id:", ec.message()); Connect(); } })); } else { - log_warning("streaming client: connection failed:", ec.message()); + log_info("streaming client: connection failed:", ec.message()); Reconnect(); } }; @@ -141,9 +130,21 @@ namespace tcp { }); } + void Client::Stop() { + _connection_timer.cancel(); + auto self = shared_from_this(); + _strand.post([this, self]() { + _done = true; + if (_socket.is_open()) { + _socket.close(); + } + }); + } + void Client::Reconnect() { + auto self = shared_from_this(); _connection_timer.expires_from_now(time_duration::seconds(1u)); - _connection_timer.async_wait([this](boost::system::error_code ec) { + _connection_timer.async_wait([this, self](boost::system::error_code ec) { if (!ec) { Connect(); } @@ -151,7 +152,8 @@ namespace tcp { } void Client::ReadData() { - _strand.post([this]() { + auto self = shared_from_this(); + _strand.post([this, self]() { if (_done) { return; } @@ -160,7 +162,7 @@ namespace tcp { auto message = std::make_shared(_buffer_pool->Pop()); - auto handle_read_data = [=](boost::system::error_code ec, size_t DEBUG_ONLY(bytes)) { + auto handle_read_data = [this, self, message](boost::system::error_code ec, size_t DEBUG_ONLY(bytes)) { DEBUG_ONLY(log_debug("streaming client: Client::ReadData.handle_read_data", bytes, "bytes")); if (!ec) { DEBUG_ASSERT_EQ(bytes, message->size()); @@ -168,16 +170,18 @@ namespace tcp { // Move the buffer to the callback function and start reading the next // piece of data. log_debug("streaming client: success reading data, calling the callback"); - _socket.get_io_service().post([this, message]() { _callback(message->pop()); }); + _socket.get_io_service().post([self, message]() { self->_callback(message->pop()); }); ReadData(); } else { // As usual, if anything fails start over from the very top. - log_warning("streaming client: failed to read data:", ec.message()); + log_info("streaming client: failed to read data:", ec.message()); Connect(); } }; - auto handle_read_header = [=](boost::system::error_code ec, size_t DEBUG_ONLY(bytes)) { + auto handle_read_header = [this, self, message, handle_read_data]( + boost::system::error_code ec, + size_t DEBUG_ONLY(bytes)) { DEBUG_ONLY(log_debug("streaming client: Client::ReadData.handle_read_header", bytes, "bytes")); if (!ec && (message->size() > 0u)) { DEBUG_ASSERT_EQ(bytes, sizeof(message_size_type)); @@ -191,7 +195,7 @@ namespace tcp { message->buffer(), _strand.wrap(handle_read_data)); } else { - log_warning("streaming client: failed to read header:", ec.message()); + log_info("streaming client: failed to read header:", ec.message()); DEBUG_ONLY(log_debug("size = ", message->size())); DEBUG_ONLY(log_debug("bytes = ", bytes)); Connect(); diff --git a/LibCarla/source/carla/streaming/detail/tcp/Client.h b/LibCarla/source/carla/streaming/detail/tcp/Client.h index e11b48053..047f16f89 100644 --- a/LibCarla/source/carla/streaming/detail/tcp/Client.h +++ b/LibCarla/source/carla/streaming/detail/tcp/Client.h @@ -30,9 +30,11 @@ namespace tcp { /// A client that connects to a single stream. /// - /// @warning The client should not be destroyed before the @a io_service is - /// stopped. - class Client : private NonCopyable { + /// @warning This client should be stopped before releasing the shared pointer + /// or won't be destroyed. + class Client + : public std::enable_shared_from_this, + private NonCopyable { public: using endpoint = boost::asio::ip::tcp::endpoint; @@ -46,15 +48,15 @@ namespace tcp { ~Client(); + void Connect(); + stream_id_type GetStreamId() const { return _token.get_stream_id(); } - private: - void Stop(); - void Connect(); + private: void Reconnect(); diff --git a/LibCarla/source/carla/streaming/detail/tcp/ServerSession.cpp b/LibCarla/source/carla/streaming/detail/tcp/ServerSession.cpp index 9cd4598d1..7c8fc4432 100644 --- a/LibCarla/source/carla/streaming/detail/tcp/ServerSession.cpp +++ b/LibCarla/source/carla/streaming/detail/tcp/ServerSession.cpp @@ -48,7 +48,7 @@ namespace tcp { _socket.get_io_service().post([=]() { cb(self); }); } else { log_error("session", _session_id, ": error retrieving stream id :", ec.message()); - Close(); + CloseNow(); } }; @@ -63,10 +63,7 @@ namespace tcp { void ServerSession::Close() { _strand.post([this, self = shared_from_this()]() { - if (_socket.is_open()) { - _socket.close(); - } - log_debug("session", _session_id, "closed"); + CloseNow(); }); } @@ -75,6 +72,9 @@ namespace tcp { DEBUG_ASSERT(!message->empty()); auto self = shared_from_this(); _strand.post([=]() { + if (!_socket.is_open()) { + return; + } if (_is_writing) { log_debug("session", _session_id, ": connection too slow: message discarded"); return; @@ -84,7 +84,8 @@ namespace tcp { auto handle_sent = [this, self, message](const boost::system::error_code &ec, size_t DEBUG_ONLY(bytes)) { _is_writing = false; if (ec) { - log_error("session", _session_id, ": error sending data :", ec.message()); + log_info("session", _session_id, ": error sending data :", ec.message()); + CloseNow(); } else { DEBUG_ONLY(log_debug("session", _session_id, ": successfully sent", bytes, "bytes")); DEBUG_ASSERT_EQ(bytes, sizeof(message_size_type) + message->size()); @@ -106,12 +107,23 @@ namespace tcp { log_debug("session", _session_id, "timed out"); Close(); } else { - _deadline.async_wait([self = shared_from_this()](boost::system::error_code) { - self->StartTimer(); + std::weak_ptr weak_self = shared_from_this(); + _deadline.async_wait([weak_self](boost::system::error_code) { + auto self = weak_self.lock(); + if (self != nullptr) { + self->StartTimer(); + } }); } } + void ServerSession::CloseNow() { + if (_socket.is_open()) { + _socket.close(); + } + log_debug("session", _session_id, "closed"); + } + } // namespace tcp } // namespace detail } // namespace streaming diff --git a/LibCarla/source/carla/streaming/detail/tcp/ServerSession.h b/LibCarla/source/carla/streaming/detail/tcp/ServerSession.h index 0b4105542..656286502 100644 --- a/LibCarla/source/carla/streaming/detail/tcp/ServerSession.h +++ b/LibCarla/source/carla/streaming/detail/tcp/ServerSession.h @@ -68,6 +68,8 @@ namespace tcp { void StartTimer(); + void CloseNow(); + friend class Server; const size_t _session_id; diff --git a/LibCarla/source/carla/streaming/low_level/Client.h b/LibCarla/source/carla/streaming/low_level/Client.h index e25055b23..71a7ad94b 100644 --- a/LibCarla/source/carla/streaming/low_level/Client.h +++ b/LibCarla/source/carla/streaming/low_level/Client.h @@ -11,6 +11,7 @@ #include +#include #include namespace carla { @@ -39,6 +40,12 @@ namespace low_level { explicit Client() : Client(carla::streaming::make_localhost_address()) {} + ~Client() { + for (auto &pair : _clients) { + pair.second->Stop(); + } + } + template void Subscribe( boost::asio::io_service &io_service, @@ -47,20 +54,29 @@ namespace low_level { if (!token.has_address()) { token.set_address(_fallback_address); } - _clients.emplace(std::piecewise_construct, - std::forward_as_tuple(token.get_stream_id()), - std::forward_as_tuple(io_service, token, std::forward(callback))); + auto client = std::make_shared( + io_service, + token, + std::forward(callback)); + client->Connect(); + _clients.emplace(token.get_stream_id(), std::move(client)); } void UnSubscribe(token_type token) { - _clients.erase(token.get_stream_id()); + auto it = _clients.find(token.get_stream_id()); + if (it != _clients.end()) { + it->second->Stop(); + _clients.erase(it); + } } private: boost::asio::ip::address _fallback_address; - std::unordered_map _clients; + std::unordered_map< + detail::stream_id_type, + std::shared_ptr> _clients; }; } // namespace low_level diff --git a/LibCarla/source/test/test_streaming_low_level.cpp b/LibCarla/source/test/test_streaming_low_level.cpp index fb391adb9..e69ec83f3 100644 --- a/LibCarla/source/test/test_streaming_low_level.cpp +++ b/LibCarla/source/test/test_streaming_low_level.cpp @@ -14,43 +14,103 @@ #include +// This is required for low level to properly stop the threads in case of +// exception/assert. +class io_service_running { +public: + + boost::asio::io_service service; + + explicit io_service_running(size_t threads = 2u) + : _work_to_do(service) { + _threads.CreateThreads(threads, [this]() { service.run(); }); + } + + ~io_service_running() { + service.stop(); + } + +private: + + boost::asio::io_service::work _work_to_do; + + carla::ThreadGroup _threads; +}; + TEST(streaming_low_level, sending_strings) { using namespace util::buffer; using namespace carla::streaming; using namespace carla::streaming::detail; using namespace carla::streaming::low_level; - constexpr auto number_of_messages = 5'000u; + constexpr auto number_of_messages = 100u; const std::string message_text = "Hello client!"; std::atomic_size_t message_count{0u}; - boost::asio::io_service io_service; + io_service_running io; - Server srv(io_service, TESTING_PORT); + Server srv(io.service, TESTING_PORT); srv.SetTimeout(1s); auto stream = srv.MakeStream(); Client c; - c.Subscribe(io_service, stream.token(), [&](auto message) { + c.Subscribe(io.service, stream.token(), [&](auto message) { ++message_count; ASSERT_EQ(message.size(), message_text.size()); const std::string msg = as_string(message); ASSERT_EQ(msg, message_text); }); - carla::ThreadGroup threads; - threads.CreateThreads( - std::max(2u, std::thread::hardware_concurrency()), - [&]() { io_service.run(); }); - for (auto i = 0u; i < number_of_messages; ++i) { + std::this_thread::sleep_for(2ms); stream << message_text; } - std::this_thread::sleep_for(1s); - io_service.stop(); - - std::cout << "client received " << message_count << " messages\n"; + std::this_thread::sleep_for(2ms); + ASSERT_EQ(message_count, number_of_messages); +} + +TEST(streaming_low_level, unsubscribing) { + using namespace util::buffer; + using namespace carla::streaming; + using namespace carla::streaming::detail; + using namespace carla::streaming::low_level; + + constexpr auto number_of_messages = 50u; + const std::string message_text = "Hello client!"; + + io_service_running io; + + Server srv(io.service, TESTING_PORT); + srv.SetTimeout(1s); + + Client c; + for (auto n = 0u; n < 10u; ++n) { + auto stream = srv.MakeStream(); + std::atomic_size_t message_count{0u}; + + c.Subscribe(io.service, stream.token(), [&](auto message) { + ++message_count; + ASSERT_EQ(message.size(), message_text.size()); + const std::string msg = as_string(message); + ASSERT_EQ(msg, message_text); + }); + + for (auto i = 0u; i < number_of_messages; ++i) { + std::this_thread::sleep_for(2ms); + stream << message_text; + } + + std::this_thread::sleep_for(2ms); + c.UnSubscribe(stream.token()); + + for (auto i = 0u; i < number_of_messages; ++i) { + std::this_thread::sleep_for(2ms); + stream << message_text; + } + + ASSERT_EQ(message_count, number_of_messages); + } } diff --git a/LibCarla/source/test/test_streaming_low_level_tcp.cpp b/LibCarla/source/test/test_streaming_low_level_tcp.cpp index 7576caa0a..dfb7873b7 100644 --- a/LibCarla/source/test/test_streaming_low_level_tcp.cpp +++ b/LibCarla/source/test/test_streaming_low_level_tcp.cpp @@ -38,13 +38,14 @@ TEST(streaming_low_level_tcp, small_message) { Dispatcher dispatcher{make_endpoint(ep)}; auto stream = dispatcher.MakeStream(); - tcp::Client c(io_service, stream.token(), [&](carla::Buffer message) { + auto c = std::make_shared(io_service, stream.token(), [&](carla::Buffer message) { ++message_count; ASSERT_FALSE(message.empty()); ASSERT_EQ(message.size(), 5u); const std::string received = util::buffer::as_string(message); ASSERT_EQ(received, msg); }); + c->Connect(); // We need at least two threads because this server loop consumes one. carla::ThreadGroup threads; @@ -57,4 +58,5 @@ TEST(streaming_low_level_tcp, small_message) { done = true; std::cout << "client received " << message_count << " messages\n"; ASSERT_GT(message_count, 10u); + c->Stop(); }