diff --git a/libappfuse/FuseBuffer.cc b/libappfuse/FuseBuffer.cc index 3ade31c54..882d54552 100644 --- a/libappfuse/FuseBuffer.cc +++ b/libappfuse/FuseBuffer.cc @@ -34,26 +34,38 @@ static_assert( "FuseBuffer must be standard layout union."); template -bool FuseMessage::CheckHeaderLength() const { +bool FuseMessage::CheckPacketSize(size_t size, const char* name) const { const auto& header = static_cast(this)->header; - if (sizeof(header) <= header.len && header.len <= sizeof(T)) { + if (size >= sizeof(header) && size <= sizeof(T)) { return true; } else { - LOG(ERROR) << "Packet size is invalid=" << header.len; + LOG(ERROR) << name << " is invalid=" << size; return false; } } template -bool FuseMessage::CheckResult( - int result, const char* operation_name) const { +bool FuseMessage::CheckResult(int result, const char* operation_name) const { + if (result == 0) { + // Expected close of other endpoints. + return false; + } + if (result < 0) { + PLOG(ERROR) << "Failed to " << operation_name << " a packet"; + return false; + } + return true; +} + +template +bool FuseMessage::CheckHeaderLength(int result, const char* operation_name) const { const auto& header = static_cast(this)->header; - if (result >= 0 && static_cast(result) == header.len) { + if (static_cast(result) == header.len) { return true; } else { - PLOG(ERROR) << "Failed to " << operation_name - << " a packet. result=" << result << " header.len=" - << header.len; + LOG(ERROR) << "Invalid header length: operation_name=" << operation_name + << " result=" << result + << " header.len=" << header.len; return false; } } @@ -61,17 +73,18 @@ bool FuseMessage::CheckResult( template bool FuseMessage::Read(int fd) { const ssize_t result = TEMP_FAILURE_RETRY(::read(fd, this, sizeof(T))); - return CheckHeaderLength() && CheckResult(result, "read"); + return CheckResult(result, "read") && CheckPacketSize(result, "read count") && + CheckHeaderLength(result, "read"); } template bool FuseMessage::Write(int fd) const { const auto& header = static_cast(this)->header; - if (!CheckHeaderLength()) { + if (!CheckPacketSize(header.len, "header.len")) { return false; } const ssize_t result = TEMP_FAILURE_RETRY(::write(fd, this, header.len)); - return CheckResult(result, "write"); + return CheckResult(result, "write") && CheckHeaderLength(result, "write"); } template class FuseMessage; diff --git a/libappfuse/include/libappfuse/FuseBuffer.h b/libappfuse/include/libappfuse/FuseBuffer.h index e7f620cb6..276db9020 100644 --- a/libappfuse/include/libappfuse/FuseBuffer.h +++ b/libappfuse/include/libappfuse/FuseBuffer.h @@ -34,8 +34,9 @@ class FuseMessage { bool Read(int fd); bool Write(int fd) const; private: - bool CheckHeaderLength() const; + bool CheckPacketSize(size_t size, const char* name) const; bool CheckResult(int result, const char* operation_name) const; + bool CheckHeaderLength(int result, const char* operation_name) const; }; // FuseRequest represents file operation requests from /dev/fuse. It starts