console: Fix issues with spice and askpass (bz 811346)

Spice opens many FDs to handle different channels (display, usb, sound,
etc.). For remote SSH URIs, this means we launch multiple SSH proceses.
We do so by forking off the process, and when SSH has successfully
authenticated, the data starts flowing.

If using spice + remote SSH w/o SSH keys, you need to put your data
into ssh askpass. askpass wants to own the display for security reasons.

When all the channel requests start coming in, we were launching multiple
ssh processes one after another. This upset askpass and generally
caused havoc in the app.

Add some infrastructure to serialize launching ssh processes. We only
launch the next ssh process if spice/vnc have conclusively connected
or errored out the connection. This makes connection a bit slower
for the non-askpass ssh case (about 1.5 seconds), but will ignore
avoid this oft reported problem.
This commit is contained in:
Cole Robinson 2013-09-06 19:36:09 -04:00
parent 67cc81f6b1
commit 5bf63759b6
2 changed files with 185 additions and 101 deletions

View File

@ -36,6 +36,19 @@ from gi.repository import Gtk
class vmmGObject(GObject.GObject):
_leak_check = True
@staticmethod
def idle_add(func, *args, **kwargs):
"""
Make sure idle functions are run thread safe
"""
def cb():
try:
return func(*args, **kwargs)
except:
print traceback.format_exc()
return False
return GLib.idle_add(cb)
def __init__(self):
GObject.GObject.__init__(self)
self.config = config.running_config
@ -141,18 +154,6 @@ class vmmGObject(GObject.GObject):
self.idle_add(emitwrap, signal, *args)
def idle_add(self, func, *args, **kwargs):
"""
Make sure idle functions are run thread safe
"""
def cb():
try:
return func(*args, **kwargs)
except:
print traceback.format_exc()
return False
return GLib.idle_add(cb)
def timeout_add(self, timeout, func, *args):
"""
Make sure timeout functions are run thread safe

View File

@ -31,10 +31,12 @@ from gi.repository import SpiceClientGLib
import libvirt
import logging
import os
import Queue
import signal
import socket
import logging
import threading
import virtManager.uihelpers as uihelpers
from virtManager.autodrawer import AutoDrawer
@ -114,14 +116,105 @@ class ConnectionInfo(object):
return int(self.gport) == -1
class Tunnel(object):
class _TunnelScheduler(object):
"""
If the user is using Spice + SSH URI + no SSH keys, we need to
serialize connection opening otherwise ssh-askpass gets all angry.
This handles the locking and scheduling.
It's only instantiated once for the whole app, because we serialize
independent of connection, vm, etc.
"""
def __init__(self):
self._thread = threading.Thread(name="Tunnel thread",
target=self._handle_queue,
args=())
self._thread.daemon = True
self._queue = Queue.Queue()
self._lock = threading.Lock()
def _handle_queue(self):
while True:
cb, args, = self._queue.get()
self.lock()
vmmGObject.idle_add(cb, *args)
def schedule(self, cb, *args):
if not self._thread.is_alive():
self._thread.start()
self._queue.put((cb, args))
def lock(self):
self._lock.acquire()
def unlock(self):
self._lock.release()
_tunnel_sched = _TunnelScheduler()
class _Tunnel(object):
def __init__(self):
self.outfd = None
self.errfd = None
self.pid = None
self._outfds = None
self._errfds = None
self.closed = False
def open(self, ginfo):
if self.outfd is not None:
self._outfds = socket.socketpair()
self._errfds = socket.socketpair()
return self._outfds[0].fileno(), self._launch_tunnel, ginfo
def close(self):
if self.closed:
return
self.closed = True
logging.debug("Close tunnel PID=%s OUTFD=%s ERRFD=%s",
self.pid,
self.outfd and self.outfd.fileno() or self._outfds,
self.errfd and self.errfd.fileno() or self._errfds)
if self.outfd:
self.outfd.close()
elif self._outfds:
self._outfds[0].close()
self._outfds[1].close()
self.outfd = None
self._outfds = None
if self.errfd:
self.errfd.close()
elif self._errfds:
self._errfds[0].close()
self._errfds[1].close()
self.errfd = None
self._errfds = None
if self.pid:
os.kill(self.pid, signal.SIGKILL)
os.waitpid(self.pid, 0)
self.pid = None
def get_err_output(self):
errout = ""
while True:
try:
new = self.errfd.recv(1024)
except:
break
if not new:
break
errout += new
return errout
def _launch_tunnel(self, ginfo):
if self.closed:
return -1
host, port, ignore = ginfo.get_conn_host()
@ -168,70 +261,33 @@ class Tunnel(object):
argv_str = reduce(lambda x, y: x + " " + y, argv[1:])
logging.debug("Creating SSH tunnel: %s", argv_str)
fds = socket.socketpair()
errorfds = socket.socketpair()
pid = os.fork()
if pid == 0:
fds[0].close()
errorfds[0].close()
self._outfds[0].close()
self._errfds[0].close()
os.close(0)
os.close(1)
os.close(2)
os.dup(fds[1].fileno())
os.dup(fds[1].fileno())
os.dup(errorfds[1].fileno())
os.dup(self._outfds[1].fileno())
os.dup(self._outfds[1].fileno())
os.dup(self._errfds[1].fileno())
os.execlp(*argv)
os._exit(1) # pylint: disable=W0212
else:
fds[1].close()
errorfds[1].close()
self._outfds[1].close()
self._errfds[1].close()
logging.debug("Tunnel PID=%d OUTFD=%d ERRFD=%d",
pid, fds[0].fileno(), errorfds[0].fileno())
errorfds[0].setblocking(0)
logging.debug("Open tunnel PID=%d OUTFD=%d ERRFD=%d",
pid, self._outfds[0].fileno(), self._errfds[0].fileno())
self._errfds[0].setblocking(0)
self.outfd = fds[0]
self.errfd = errorfds[0]
self.outfd = self._outfds[0]
self.errfd = self._errfds[0]
self._outfds = None
self._errfds = None
self.pid = pid
fd = fds[0].fileno()
if fd < 0:
raise SystemError("can't open a new tunnel: fd=%d" % fd)
return fd
def close(self):
if self.outfd is None:
return
logging.debug("Shutting down tunnel PID=%d OUTFD=%d ERRFD=%d",
self.pid, self.outfd.fileno(),
self.errfd.fileno())
self.outfd.close()
self.outfd = None
self.errfd.close()
self.errfd = None
os.kill(self.pid, signal.SIGKILL)
os.waitpid(self.pid, 0)
self.pid = None
def get_err_output(self):
errout = ""
while True:
try:
new = self.errfd.recv(1024)
except:
break
if not new:
break
errout += new
return errout
class Tunnels(object):
def __init__(self, ginfo):
@ -239,9 +295,11 @@ class Tunnels(object):
self._tunnels = []
def open_new(self):
t = Tunnel()
fd = t.open(self.ginfo)
t = _Tunnel()
fd, cb, args = t.open(self.ginfo)
self._tunnels.append(t)
_tunnel_sched.schedule(cb, args)
return fd
def close_all(self):
@ -254,6 +312,9 @@ class Tunnels(object):
errout += l.get_err_output()
return errout
lock = _tunnel_sched.lock
unlock = _tunnel_sched.unlock
class Viewer(vmmGObject):
def __init__(self, console):
@ -275,6 +336,12 @@ class Viewer(vmmGObject):
def get_pixbuf(self):
return self.display.get_pixbuf()
def open_ginfo(self, ginfo):
if ginfo.need_tunnel():
self.open_fd(self.console.tunnels.open_new())
else:
self.open_host(ginfo)
def get_grab_keys(self):
raise NotImplementedError()
@ -284,10 +351,10 @@ class Viewer(vmmGObject):
def send_keys(self, keys):
raise NotImplementedError()
def open_host(self, ginfo, password=None):
def open_host(self, ginfo):
raise NotImplementedError()
def open_fd(self, fd, password=None):
def open_fd(self, fd):
raise NotImplementedError()
def get_desktop_resolution(self):
@ -306,6 +373,8 @@ class VNCViewer(Viewer):
# Last noticed desktop resolution
self.desktop_resolution = None
self._tunnel_unlocked = False
def init_widget(self):
self.set_grab_keys()
@ -320,18 +389,32 @@ class VNCViewer(Viewer):
self.display.set_pointer_grab(True)
self.display.connect("vnc-pointer-grab", self.console.pointer_grabbed)
self.display.connect("vnc-pointer-ungrab", self.console.pointer_ungrabbed)
self.display.connect("vnc-pointer-ungrab",
self.console.pointer_ungrabbed)
self.display.connect("vnc-auth-credential", self._auth_credential)
self.display.connect("vnc-initialized",
lambda src: self.console.connected())
self.display.connect("vnc-disconnected",
lambda src: self.console.disconnected())
self.display.connect("vnc-initialized", self._connected_cb)
self.display.connect("vnc-disconnected", self._disconnected_cb)
self.display.connect("vnc-desktop-resize", self._desktop_resize)
self.display.connect("focus-in-event", self.console.viewer_focus_changed)
self.display.connect("focus-out-event", self.console.viewer_focus_changed)
self.display.connect("focus-in-event",
self.console.viewer_focus_changed)
self.display.connect("focus-out-event",
self.console.viewer_focus_changed)
self.display.show()
def _unlock_tunnel(self):
if self.console.tunnels and not self._tunnel_unlocked:
self.console.tunnels.unlock()
self._tunnel_unlocked = True
def _connected_cb(self, ignore):
self._unlock_tunnel()
self.console.connected()
def _disconnected_cb(self, ignore):
self._unlock_tunnel()
self.console.disconnected()
def get_grab_keys(self):
return self.display.get_grab_keys().as_string()
@ -421,7 +504,7 @@ class VNCViewer(Viewer):
def is_open(self):
return self.display.is_open()
def open_host(self, ginfo, password=None):
def open_host(self, ginfo):
host, port, ignore = ginfo.get_conn_host()
if not ginfo.gsocket:
@ -444,8 +527,7 @@ class VNCViewer(Viewer):
ginfo.gsocket) + " fd=%s" % fd)
self.open_fd(fd)
def open_fd(self, fd, password=None):
ignore = password
def open_fd(self, fd):
self.display.open_fd(fd)
def set_credential_username(self, cred):
@ -469,8 +551,10 @@ class SpiceViewer(Viewer):
self.console.refresh_scaling()
self.display.realize()
self.display.connect("mouse-grab", lambda src, g: g and self.console.pointer_grabbed(src))
self.display.connect("mouse-grab", lambda src, g: g or self.console.pointer_ungrabbed(src))
self.display.connect("mouse-grab",
lambda src, g: g and self.console.pointer_grabbed(src))
self.display.connect("mouse-grab",
lambda src, g: g or self.console.pointer_ungrabbed(src))
self.display.connect("focus-in-event",
self.console.viewer_focus_changed)
@ -534,11 +618,19 @@ class SpiceViewer(Viewer):
logging.debug("Spice channel event error: %s", event)
self.console.disconnected()
def _fd_channel_event_cb(self, channel, event):
# When we see any event from the channel, release the
# associated tunnel lock
channel.disconnect_by_func(self._fd_channel_event_cb)
self.console.tunnels.unlock()
def _channel_open_fd_request(self, channel, tls_ignore):
if not self.console.tunnels:
raise SystemError("Got fd request with no configured tunnel!")
logging.debug("Opening tunnel for channel: %s", channel)
channel.connect_after("channel-event", self._fd_channel_event_cb)
fd = self.console.tunnels.open_new()
channel.open_fd(fd)
@ -547,6 +639,8 @@ class SpiceViewer(Viewer):
self._channel_open_fd_request)
if type(channel) == SpiceClientGLib.MainChannel:
if self.console.tunnels:
self.console.tunnels.unlock()
channel.connect_after("channel-event", self._main_channel_event_cb)
return
@ -584,6 +678,9 @@ class SpiceViewer(Viewer):
gtk_session = SpiceClientGtk.GtkSession.get(self.spice_session)
gtk_session.set_property("auto-clipboard", True)
GObject.GObject.connect(self.spice_session, "channel-new",
self._channel_new_cb)
self.usbdev_manager = SpiceClientGLib.UsbDeviceManager.get(
self.spice_session)
self.usbdev_manager.connect("auto-connect-failed",
@ -595,26 +692,19 @@ class SpiceViewer(Viewer):
if autoredir:
gtk_session.set_property("auto-usbredir", True)
def open_host(self, ginfo, password=None):
def open_host(self, ginfo):
host, port, tlsport = ginfo.get_conn_host()
self._create_spice_session()
self.spice_session.set_property("host", str(host))
self.spice_session.set_property("port", str(port))
if tlsport:
self.spice_session.set_property("tls-port", str(tlsport))
if password:
self.spice_session.set_property("password", password)
GObject.GObject.connect(self.spice_session, "channel-new",
self._channel_new_cb)
self.spice_session.connect()
def open_fd(self, fd, password=None):
def open_fd(self, fd):
self._create_spice_session()
if password:
self.spice_session.set_property("password", password)
GObject.GObject.connect(self.spice_session, "channel-new",
self._channel_new_cb)
self.spice_session.open_fd(fd)
def set_credential_password(self, cred):
@ -1254,15 +1344,8 @@ class vmmConsolePages(vmmGObjectUI):
self.set_enable_accel()
if ginfo.need_tunnel():
if self.tunnels:
# Tunnel already open, no need to continue
return
self.tunnels = Tunnels(ginfo)
self.viewer.open_fd(self.tunnels.open_new())
else:
self.viewer.open_host(ginfo)
self.viewer.open_ginfo(ginfo)
except Exception, e:
logging.exception("Error connection to graphical console")
self.activate_unavailable_page(