viewer: Simplify tunnel handling

- Cache ginfo, since it is tied to a single viewer instance
- Always create the sshtunnels object
- Track lock state so we can make it idempotent
This commit is contained in:
Cole Robinson 2015-04-11 21:39:24 -04:00
parent 93f826a8b2
commit 285e345f17
3 changed files with 74 additions and 77 deletions

View File

@ -709,12 +709,12 @@ class vmmConsolePages(vmmGObjectUI):
elif ginfo.gtype == "spice":
viewer_class = SpiceViewer
self._viewer = viewer_class()
self._viewer = viewer_class(ginfo)
self._connect_viewer_signals()
self._refresh_enable_accel()
self._viewer.console_open_ginfo(ginfo)
self._viewer.console_open()
except Exception, e:
logging.exception("Error connection to graphical console")
self._activate_unavailable_page(

View File

@ -104,23 +104,25 @@ class _TunnelScheduler(object):
independent of connection, vm, etc.
"""
def __init__(self):
self._thread = threading.Thread(name="Tunnel thread",
target=self._handle_queue,
args=())
self._thread.daemon = True
self._thread = None
self._queue = Queue.Queue()
self._lock = threading.Lock()
def _handle_queue(self):
while True:
cb, args, = self._queue.get()
self.lock()
lock_cb, cb, args, = self._queue.get()
lock_cb()
vmmGObject.idle_add(cb, *args)
def schedule(self, cb, *args):
def schedule(self, lock_cb, cb, *args):
if not self._thread:
self._thread = threading.Thread(name="Tunnel thread",
target=self._handle_queue,
args=())
self._thread.daemon = True
if not self._thread.is_alive():
self._thread.start()
self._queue.put((cb, args))
self._queue.put((lock_cb, cb, args))
def lock(self):
self._lock.acquire()
@ -128,14 +130,17 @@ class _TunnelScheduler(object):
self._lock.release()
_tunnel_scheduler = _TunnelScheduler()
class _Tunnel(object):
def __init__(self):
self.outfd = None
self.errfd = None
self.pid = None
self._outfd = None
self._errfd = None
self._pid = None
self._outfds = None
self._errfds = None
self.closed = False
self._closed = False
def open(self, ginfo):
self._outfds = socket.socketpair()
@ -144,38 +149,38 @@ class _Tunnel(object):
return self._outfds[0].fileno(), self._launch_tunnel, ginfo
def close(self):
if self.closed:
if self._closed:
return
self.closed = True
self._closed = True
logging.debug("Close tunnel PID=%s OUTFD=%s ERRFD=%s",
self.pid,
self.outfd and self.outfd.fileno() or self._outfds,
self.errfd and self.errfd.fileno() or self._errfds)
self._pid,
self._outfd and self._outfd.fileno() or self._outfds,
self._errfd and self._errfd.fileno() or self._errfds)
if self._outfds:
self._outfds[1].close()
self.outfd = None
self._outfd = None
self._outfds = None
if self.errfd:
self.errfd.close()
if self._errfd:
self._errfd.close()
elif self._errfds:
self._errfds[0].close()
self._errfds[1].close()
self.errfd = None
self._errfd = None
self._errfds = None
if self.pid:
os.kill(self.pid, signal.SIGKILL)
os.waitpid(self.pid, 0)
self.pid = None
if self._pid:
os.kill(self._pid, signal.SIGKILL)
os.waitpid(self._pid, 0)
self._pid = None
def get_err_output(self):
errout = ""
while True:
try:
new = self.errfd.recv(1024)
new = self._errfd.recv(1024)
except:
break
@ -187,7 +192,7 @@ class _Tunnel(object):
return errout
def _launch_tunnel(self, ginfo):
if self.closed:
if self._closed:
return -1
host, port, ignore = ginfo.get_conn_host()
@ -255,31 +260,31 @@ class _Tunnel(object):
pid, self._outfds[0].fileno(), self._errfds[0].fileno())
self._errfds[0].setblocking(0)
self.outfd = self._outfds[0]
self.errfd = self._errfds[0]
self._outfd = self._outfds[0]
self._errfd = self._errfds[0]
self._outfds = None
self._errfds = None
self.pid = pid
self._pid = pid
class SSHTunnels(object):
_tunnel_sched = _TunnelScheduler()
def __init__(self, ginfo):
self.ginfo = ginfo
self._tunnels = []
self._ginfo = ginfo
self._locked = False
def open_new(self):
t = _Tunnel()
fd, cb, args = t.open(self.ginfo)
fd, cb, args = t.open(self._ginfo)
self._tunnels.append(t)
self._tunnel_sched.schedule(cb, args)
_tunnel_scheduler.schedule(self._lock, cb, args)
return fd
def close_all(self):
for l in self._tunnels:
l.close()
self._tunnels = []
def get_err_output(self):
errout = ""
@ -287,7 +292,11 @@ class SSHTunnels(object):
errout += l.get_err_output()
return errout
def lock(self, *args, **kwargs):
return self._tunnel_sched.lock(*args, **kwargs)
def _lock(self):
_tunnel_scheduler.lock()
self._locked = True
def unlock(self, *args, **kwargs):
return self._tunnel_sched.unlock(*args, **kwargs)
if self._locked:
_tunnel_scheduler.unlock(*args, **kwargs)
self._locked = False

View File

@ -59,10 +59,11 @@ class Viewer(vmmGObject):
"usb-redirect-error": (GObject.SignalFlags.RUN_FIRST, None, [str]),
}
def __init__(self):
def __init__(self, ginfo):
vmmGObject.__init__(self)
self._display = None
self._tunnels = None
self._ginfo = ginfo
self._tunnels = SSHTunnels(self._ginfo)
self.add_gsettings_handle(
self.config.on_keys_combination_changed(self._refresh_grab_keys))
@ -79,7 +80,8 @@ class Viewer(vmmGObject):
self._display.destroy()
self._display = None
self._reset_tunnels()
if self._tunnels:
self._reset_tunnels()
self._tunnels = None
@ -115,9 +117,6 @@ class Viewer(vmmGObject):
#########################
def _reset_tunnels(self):
if not self._tunnels:
return
errout = self._tunnels.get_err_output()
self._tunnels.close_all()
self._tunnels = None
@ -138,12 +137,11 @@ class Viewer(vmmGObject):
def _get_pixbuf(self):
return self._display.get_pixbuf()
def _open_ginfo(self, ginfo):
if ginfo.need_tunnel():
self._tunnels = SSHTunnels(ginfo)
def _open(self):
if self._ginfo.need_tunnel():
self._open_fd(self._tunnels.open_new())
else:
self._open_host(ginfo)
self._open_host()
def _get_grab_keys(self):
return self._display.get_grab_keys().as_string()
@ -173,7 +171,7 @@ class Viewer(vmmGObject):
def _refresh_keyboard_grab_default(self):
raise NotImplementedError()
def _open_host(self, ginfo):
def _open_host(self):
raise NotImplementedError()
def _open_fd(self, fd):
raise NotImplementedError()
@ -220,8 +218,8 @@ class Viewer(vmmGObject):
def console_get_pixbuf(self):
return self._get_pixbuf()
def console_open_ginfo(self, ginfo):
return self._open_ginfo(ginfo)
def console_open(self):
return self._open()
def console_set_password(self, val):
return self._set_password(val)
@ -270,7 +268,6 @@ class VNCViewer(Viewer):
self._display = None
self._sockfd = None
self._desktop_resolution = None
self._tunnel_unlocked = False
###################
@ -299,17 +296,12 @@ class VNCViewer(Viewer):
self._display.show()
def _unlock_tunnel(self):
if self._tunnels and not self._tunnel_unlocked:
self._tunnels.unlock()
self._tunnel_unlocked = True
def _connected_cb(self, ignore):
self._unlock_tunnel()
self._tunnels.unlock()
self.emit("connected")
def _disconnected_cb(self, ignore):
self._unlock_tunnel()
self._tunnels.unlock()
self.emit("disconnected")
def _desktop_resize(self, src_ignore, w, h):
@ -428,31 +420,31 @@ class VNCViewer(Viewer):
# Connection routines #
#######################
def _open_ginfo(self, *args, **kwargs):
def _open(self, *args, **kwargs):
self._init_widget()
return Viewer._open_ginfo(self, *args, **kwargs)
return Viewer._open(self, *args, **kwargs)
def _open_host(self, ginfo):
host, port, ignore = ginfo.get_conn_host()
def _open_host(self):
host, port, ignore = self._ginfo.get_conn_host()
if not ginfo.gsocket:
if not self._ginfo.gsocket:
logging.debug("VNC connection to %s:%s", host, port)
self._display.open_host(host, port)
return
logging.debug("VNC connecting to socket=%s", ginfo.gsocket)
logging.debug("VNC connecting to socket=%s", self._ginfo.gsocket)
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(ginfo.gsocket)
sock.connect(self._ginfo.gsocket)
self._sockfd = sock
except Exception, e:
raise RuntimeError(_("Error opening socket path '%s': %s") %
(ginfo.gsocket, e))
(self._ginfo.gsocket, e))
fd = self._sockfd.fileno()
if fd < 0:
raise RuntimeError((_("Error opening socket path '%s'") %
ginfo.gsocket) + " fd=%s" % fd)
self._ginfo.gsocket) + " fd=%s" % fd)
self._open_fd(fd)
def _open_fd(self, fd):
@ -548,9 +540,6 @@ class SpiceViewer(Viewer):
self._tunnels.unlock()
def _channel_open_fd_request(self, channel, tls_ignore):
if not self._tunnels:
raise SystemError("Got fd request with no configured tunnel!")
logging.debug("Opening tunnel for channel: %s", channel)
channel.connect_after("channel-event", self._fd_channel_event_cb)
@ -563,8 +552,7 @@ class SpiceViewer(Viewer):
if (type(channel) == SpiceClientGLib.MainChannel and
not self._main_channel):
if self._tunnels:
self._tunnels.unlock()
self._tunnels.unlock()
self._main_channel = channel
hid = self._main_channel.connect_after("channel-event",
self._main_channel_event_cb)
@ -658,8 +646,8 @@ class SpiceViewer(Viewer):
return False
return self._main_channel.get_property("agent-connected")
def _open_host(self, ginfo):
host, port, tlsport = ginfo.get_conn_host()
def _open_host(self):
host, port, tlsport = self._ginfo.get_conn_host()
self._create_spice_session()
self._spice_session.set_property("host", str(host))