some source code about twisted

This commit is contained in:
JamesonHuang 2015-06-15 22:20:58 +08:00
parent 8791e644e5
commit dff86defdb
1113 changed files with 407155 additions and 0 deletions

View File

@ -0,0 +1,9 @@
*.egg-info/
*.o
*.py[co]
*.so
_trial_temp*/
build/
dropin.cache
doc/
docs/_build/

View File

@ -0,0 +1,21 @@
Contributing to Twisted
=======================
As an open source project, Twisted welcomes contributions of many forms.
Examples of contributions include:
* Code patches
* Documentation improvements
* Bug reports and patch reviews
Extensive contribution guidelines are available online at:
https://twistedmatrix.com/trac/wiki/ContributingToTwistedLabs
**Warning: pull requests are ignored!** File a ticket at:
https://twistedmatrix.com/trac/newticket
Twisted uses Trac to keep track of bugs, feature requests, and associated
patches because GitHub doesn't provide adequate tooling for its community.

View File

@ -0,0 +1,32 @@
Requirements
Python 2.6 or 2.7.
Zope Interface 3.6.0 or better (http://pypi.python.org/pypi/zope.interface)
pyOpenSSL (<http://launchpad.net/pyopenssl>) is required for any SSL APIs.
Version 0.10 or newer is required.
On Windows pywin32 (<http://sourceforge.net/projects/pywin32/files/>) is
required. Build 215 or later is highly recommended for reliable operation
(this is already included in ActivePython).
If you would like to use Trial's subunit reporter, then you will need to
install Subunit 0.0.2 or later (https://launchpad.net/subunit).
Installation
* Debian and Ubuntu
Packages are included in the main distribution.
* FreeBSD, Gentoo
Twisted is in their package repositories.
* Win32
Installers are available from http://twistedmatrix.com/
* Other
As with other Python packages, the standard way of installing from source
is:
python setup.py install

View File

@ -0,0 +1,67 @@
Copyright (c) 2001-2015
Allen Short
Andy Gayton
Andrew Bennetts
Antoine Pitrou
Apple Computer, Inc.
Ashwini Oruganti
Benjamin Bruheim
Bob Ippolito
Canonical Limited
Christopher Armstrong
David Reid
Donovan Preston
Eric Mangold
Eyal Lotem
Google Inc.
Hawkie Owl
Hybrid Logic Ltd.
Hynek Schlawack
Itamar Turner-Trauring
James Knight
Jason A. Mobarak
Jean-Paul Calderone
Jessica McKellar
Jonathan Jacobs
Jonathan Lange
Jonathan D. Simms
Jürgen Hermann
Julian Berman
Kevin Horn
Kevin Turner
Laurens Van Houtven
Mary Gardiner
Matthew Lefkowitz
Massachusetts Institute of Technology
Moshe Zadka
Paul Swartz
Pavel Pergamenshchik
Ralph Meijer
Richard Wall
Sean Riley
Software Freedom Conservancy
Tavendo GmbH
Travis B. Hartwell
Thijs Triemstra
Thomas Herve
Timothy Allen
Tom Prince
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,114 @@
Twisted 15.2.1
Quote of the Release:
<hynek> is there a race condition in threading tests? how could that happen :>
For information on what's new in Twisted 15.2.1, see the NEWS file that comes
with the distribution.
What is this?
=============
Twisted is an event-based framework for internet applications. It includes
modules for many different purposes, including the following:
- twisted.application
A "Service" system that allows you to organize your application in
hierarchies with well-defined startup and dependency semantics,
- twisted.cred
A general credentials and authentication system that facilitates
pluggable authentication backends,
- twisted.enterprise
Asynchronous database access, compatible with any Python DBAPI2.0
modules,
- twisted.internet
Low-level asynchronous networking APIs that allow you to define
your own protocols that run over certain transports,
- twisted.manhole
A tool for remote debugging of your services which gives you a
Python interactive interpreter,
- twisted.protocols
Basic protocol implementations and helpers for your own protocol
implementations,
- twisted.python
A large set of utilities for Python tricks, reflection, text
processing, and anything else,
- twisted.spread
A secure, fast remote object system,
- twisted.trial
A unit testing framework that integrates well with Twisted-based code.
Twisted supports integration of the Win32, Tk, GTK+ and GTK+ 2 event loops
with its main event loop. There is experimental support for Mac OS X and
wxPython event loop integration, which you use at your peril.
For more information, visit http://www.twistedmatrix.com, or join the list
at http://twistedmatrix.com/cgi-bin/mailman/listinfo/twisted-python
There are many official Twisted subprojects, including clients and
servers for web, mail, DNS, and more. You can find out more about
these projects at http://twistedmatrix.com/trac/wiki/TwistedProjects
Installing
==========
Instructions for installing this software are in INSTALL.
Unit Tests
==========
See our unit tests run proving that the software is BugFree(TM):
% trial twisted
Some of these tests may fail if you
* don't have the dependancies required for a particular subsystem installed,
* have a firewall blocking some ports (or things like Multicast, which Linux
NAT has shown itself to do), or
* run them as root.
Documentation and Support
=========================
Twisted's documentation is available from the Twisted Matrix website:
http://twistedmatrix.com/documents/current/
This documentation contains how-tos, code examples, and an API reference.
Help is also available on the Twisted mailing list:
http://twistedmatrix.com/cgi-bin/mailman/listinfo/twisted-python
There is also a pair of very lively IRC channels, #twisted (for general
Twisted questions) and #twisted.web (for Twisted Web), on chat.freenode.net.
Copyright
=========
All of the code in this distribution is Copyright (c) 2001-2015
Twisted Matrix Laboratories.
Twisted is made available under the MIT license. The included
LICENSE file describes this in detail.
Warranty
========
THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
TO THE USE OF THIS SOFTWARE IS WITH YOU.
IN NO EVENT WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
AND/OR REDISTRIBUTE THE LIBRARY, BE LIABLE TO YOU FOR ANY DAMAGES, EVEN IF
SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
DAMAGES.
Again, see the included LICENSE file for specific legal details.

View File

@ -0,0 +1,19 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# This makes sure that users don't have to set up their environment
# specially in order to run these programs from bin/.
# This helper is shared by many different actual scripts. It is not intended to
# be packaged or installed, it is only a developer convenience. By the time
# Twisted is actually installed somewhere, the environment should already be set
# up properly without the help of this tool.
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.exists(os.path.join(path, 'twisted', '__init__.py')):
sys.path.insert(0, path)
break
path = os.path.dirname(path)

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.conch.scripts.cftp import run
run()

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.conch.scripts.ckeygen import run
run()

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.conch.scripts.conch import run
run()

View File

@ -0,0 +1,15 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.conch.scripts.tkconch import run
run()

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.lore.scripts.lore import run
run()

View File

@ -0,0 +1,20 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This script attempts to send some email.
"""
import sys, os
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
sys.path.insert(0, extra)
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.remove(extra)
from twisted.mail.scripts import mailmail
mailmail.run()

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This script runs GtkManhole, a client for Twisted.Manhole
"""
import sys
try:
import _preamble
except ImportError:
sys.exc_clear()
from twisted.scripts import manhole
manhole.run()

View File

@ -0,0 +1,12 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import sys
try:
import _preamble
except ImportError:
sys.exc_clear()
from twisted.scripts.htmlizer import run
run()

View File

@ -0,0 +1,16 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
tap2deb
"""
import sys
try:
__import__('_preamble')
except ImportError:
sys.exc_clear()
from twisted.scripts import tap2deb
tap2deb.run()

View File

@ -0,0 +1,19 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# based off the tap2deb code
# tap2rpm built by Sean Reifschneider, <jafo@tummy.com>
"""
tap2rpm
"""
import sys
try:
import _preamble
except ImportError:
sys.exc_clear()
from twisted.scripts import tap2rpm
tap2rpm.run()

View File

@ -0,0 +1,18 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, sys
try:
import _preamble
except ImportError:
sys.exc_clear()
# begin chdir armor
sys.path[:] = map(os.path.abspath, sys.path)
# end chdir armor
sys.path.insert(0, os.path.abspath(os.getcwd()))
from twisted.scripts.trial import run
run()

View File

@ -0,0 +1,14 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, sys
try:
import _preamble
except ImportError:
sys.exc_clear()
sys.path.insert(0, os.path.abspath(os.getcwd()))
from twisted.scripts.twistd import run
run()

View File

@ -0,0 +1,76 @@
#!/usr/bin/env python
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Distutils installer for Twisted.
"""
try:
# Load setuptools, to build a specific source package
import setuptools
# Tell Twisted not to enforce zope.interface requirement on import, since
# we're going to have to import twisted.python.dist and can rely on
# setuptools to install dependencies.
setuptools._TWISTED_NO_CHECK_REQUIREMENTS = True
except ImportError:
pass
import os
import sys
def main(args):
"""
Invoke twisted.python.dist with the appropriate metadata about the
Twisted package.
"""
# On Python 3, use setup3.py until Python 3 port is done:
if sys.version_info[0] > 2:
import setup3
setup3.main()
return
if os.path.exists('twisted'):
sys.path.insert(0, '.')
setup_args = {}
if 'setuptools' in sys.modules:
from pkg_resources import parse_requirements
requirements = ["zope.interface >= 3.6.0"]
try:
list(parse_requirements(requirements))
except:
print("""You seem to be running a very old version of setuptools.
This version of setuptools has a bug parsing dependencies, so automatic
dependency resolution is disabled.
""")
else:
setup_args['install_requires'] = requirements
setup_args['include_package_data'] = True
setup_args['zip_safe'] = False
from twisted.python.dist import (
STATIC_PACKAGE_METADATA, getDataFiles, getExtensions, getAllScripts,
getPackages, setup, _EXTRAS_REQUIRE)
scripts = getAllScripts()
setup_args.update(dict(
packages=getPackages('twisted'),
conditionalExtensions=getExtensions(),
scripts=scripts,
extras_require=_EXTRAS_REQUIRE,
data_files=getDataFiles('twisted'),
**STATIC_PACKAGE_METADATA))
setup(**setup_args)
if __name__ == "__main__":
try:
main(sys.argv[1:])
except KeyboardInterrupt:
sys.exit(1)

View File

@ -0,0 +1,51 @@
#!/usr/bin/env python3.3
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# This is a temporary helper to be able to build and install distributions of
# Twisted on/for Python 3. Once all of Twisted has been ported, it should go
# away and setup.py should work for either Python 2 or Python 3.
from __future__ import division, absolute_import
import sys
import os
from distutils.command.sdist import sdist
class DisabledSdist(sdist):
"""
A version of the sdist command that does nothing.
"""
def run(self):
sys.stderr.write(
"The sdist command only works with Python 2 at the moment.\n")
sys.exit(1)
def main():
try:
from setuptools import setup
except ImportError:
from distutils.core import setup
# Make sure the to-be-installed version of Twisted is used, if available,
# since we're importing from it:
if os.path.exists('twisted'):
sys.path.insert(0, '.')
from twisted.python.dist3 import modulesToInstall
from twisted.python.dist import STATIC_PACKAGE_METADATA
args = STATIC_PACKAGE_METADATA.copy()
args['install_requires'] = ["zope.interface >= 4.0.2"]
args['py_modules'] = modulesToInstall
args['cmdclass'] = {'sdist': DisabledSdist}
setup(**args)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,71 @@
# -*- test-case-name: twisted -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted: The Framework Of Your Internet.
"""
def _checkRequirements():
# Don't allow the user to run a version of Python we don't support.
import sys
version = getattr(sys, "version_info", (0,))
if version < (2, 6):
raise ImportError("Twisted requires Python 2.6 or later.")
if version < (3, 0):
required = "3.6.0"
else:
required = "4.0.0"
if ("setuptools" in sys.modules and
getattr(sys.modules["setuptools"],
"_TWISTED_NO_CHECK_REQUIREMENTS", None) is not None):
# Skip requirement checks, setuptools ought to take care of installing
# the dependencies.
return
# Don't allow the user to run with a version of zope.interface we don't
# support.
required = "Twisted requires zope.interface %s or later" % (required,)
try:
from zope import interface
except ImportError:
# It isn't installed.
raise ImportError(required + ": no module named zope.interface.")
except:
# It is installed but not compatible with this version of Python.
raise ImportError(required + ".")
try:
# Try using the API that we need, which only works right with
# zope.interface 3.6 (or 4.0 on Python 3)
class IDummy(interface.Interface):
pass
@interface.implementer(IDummy)
class Dummy(object):
pass
except TypeError:
# It is installed but not compatible with this version of Python.
raise ImportError(required + ".")
_checkRequirements()
# Ensure compat gets imported
from twisted.python import compat
# setup version
from twisted._version import version
__version__ = version.short()
del compat
# Deprecating lore.
from twisted.python.versions import Version
from twisted.python.deprecate import deprecatedModuleAttribute
deprecatedModuleAttribute(
Version("Twisted", 14, 0, 0),
"Use Sphinx instead.",
"twisted", "lore")

View File

@ -0,0 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# This is an auto-generated file. Do not edit it.
"""
Provides Twisted version information.
"""
from twisted.python import versions
version = versions.Version('twisted', 15, 2, 1)

View File

@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Configuration objects for Twisted Applications.
"""

View File

@ -0,0 +1,678 @@
# -*- test-case-name: twisted.test.test_application,twisted.test.test_twistd -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import print_function
import sys
import os
import pdb
import getpass
import traceback
import signal
from operator import attrgetter
from twisted.python import runtime, log, usage, failure, util, logfile
from twisted.python.reflect import qual, namedAny
from twisted.python.log import ILogObserver
from twisted.persisted import sob
from twisted.application import service, reactors
from twisted.internet import defer
from twisted import copyright, plugin
# Expose the new implementation of installReactor at the old location.
from twisted.application.reactors import installReactor
from twisted.application.reactors import NoSuchReactor
class _BasicProfiler(object):
"""
@ivar saveStats: if C{True}, save the stats information instead of the
human readable format
@type saveStats: C{bool}
@ivar profileOutput: the name of the file use to print profile data.
@type profileOutput: C{str}
"""
def __init__(self, profileOutput, saveStats):
self.profileOutput = profileOutput
self.saveStats = saveStats
def _reportImportError(self, module, e):
"""
Helper method to report an import error with a profile module. This
has to be explicit because some of these modules are removed by
distributions due to them being non-free.
"""
s = "Failed to import module %s: %s" % (module, e)
s += """
This is most likely caused by your operating system not including
the module due to it being non-free. Either do not use the option
--profile, or install the module; your operating system vendor
may provide it in a separate package.
"""
raise SystemExit(s)
class ProfileRunner(_BasicProfiler):
"""
Runner for the standard profile module.
"""
def run(self, reactor):
"""
Run reactor under the standard profiler.
"""
try:
import profile
except ImportError as e:
self._reportImportError("profile", e)
p = profile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
tmp, sys.stdout = sys.stdout, open(self.profileOutput, 'a')
try:
p.print_stats()
finally:
sys.stdout, tmp = tmp, sys.stdout
tmp.close()
class HotshotRunner(_BasicProfiler):
"""
Runner for the hotshot profile module.
"""
def run(self, reactor):
"""
Run reactor under the hotshot profiler.
"""
try:
import hotshot.stats
except (ImportError, SystemExit) as e:
# Certain versions of Debian (and Debian derivatives) raise
# SystemExit when importing hotshot if the "non-free" profiler
# module is not installed. Someone eventually recognized this
# as a bug and changed the Debian packaged Python to raise
# ImportError instead. Handle both exception types here in
# order to support the versions of Debian which have this
# behavior. The bug report which prompted the introduction of
# this highly undesirable behavior should be available online at
# <http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=334067>.
# There seems to be no corresponding bug report which resulted
# in the behavior being removed. -exarkun
self._reportImportError("hotshot", e)
# this writes stats straight out
p = hotshot.Profile(self.profileOutput)
p.runcall(reactor.run)
if self.saveStats:
# stats are automatically written to file, nothing to do
return
else:
s = hotshot.stats.load(self.profileOutput)
s.strip_dirs()
s.sort_stats(-1)
s.stream = open(self.profileOutput, 'w')
s.print_stats()
s.stream.close()
class CProfileRunner(_BasicProfiler):
"""
Runner for the cProfile module.
"""
def run(self, reactor):
"""
Run reactor under the cProfile profiler.
"""
try:
import cProfile
import pstats
except ImportError as e:
self._reportImportError("cProfile", e)
p = cProfile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
stream = open(self.profileOutput, 'w')
s = pstats.Stats(p, stream=stream)
s.strip_dirs()
s.sort_stats(-1)
s.print_stats()
stream.close()
class AppProfiler(object):
"""
Class which selects a specific profile runner based on configuration
options.
@ivar profiler: the name of the selected profiler.
@type profiler: C{str}
"""
profilers = {"profile": ProfileRunner, "hotshot": HotshotRunner,
"cprofile": CProfileRunner}
def __init__(self, options):
saveStats = options.get("savestats", False)
profileOutput = options.get("profile", None)
self.profiler = options.get("profiler", "hotshot").lower()
if self.profiler in self.profilers:
profiler = self.profilers[self.profiler](profileOutput, saveStats)
self.run = profiler.run
else:
raise SystemExit("Unsupported profiler name: %s" %
(self.profiler,))
class AppLogger(object):
"""
An L{AppLogger} attaches the configured log observer specified on the
commandline to a L{ServerOptions} object or the custom L{ILogObserver}.
@ivar _logfilename: The name of the file to which to log, if other than the
default.
@type _logfilename: C{str}
@ivar _observerFactory: Callable object that will create a log observer, or
None.
@ivar _observer: log observer added at C{start} and removed at C{stop}.
@type _observer: C{callable}
"""
_observer = None
def __init__(self, options):
"""
Initialize an L{AppLogger} with a L{ServerOptions}.
"""
self._logfilename = options.get("logfile", "")
self._observerFactory = options.get("logger") or None
def start(self, application):
"""
Initialize the global logging system for the given application.
If a custom logger was specified on the command line it will be used.
If not, and an L{ILogObserver} component has been set on
C{application}, then it will be used as the log observer. Otherwise a
log observer will be created based on the command-line options for
built-in loggers (e.g. C{--logfile}).
@param application: The application on which to check for an
L{ILogObserver}.
@type application: L{twisted.python.components.Componentized}
"""
if self._observerFactory is not None:
observer = self._observerFactory()
else:
observer = application.getComponent(ILogObserver, None)
if observer is None:
observer = self._getLogObserver()
self._observer = observer
log.startLoggingWithObserver(self._observer)
self._initialLog()
def _initialLog(self):
"""
Print twistd start log message.
"""
from twisted.internet import reactor
log.msg("twistd %s (%s %s) starting up." % (
copyright.version, sys.executable, runtime.shortPythonVersion())
)
log.msg('reactor class: %s.' % (qual(reactor.__class__),))
def _getLogObserver(self):
"""
Create a log observer to be added to the logging system before running
this application.
"""
if self._logfilename == '-' or not self._logfilename:
logFile = sys.stdout
else:
logFile = logfile.LogFile.fromFullPath(self._logfilename)
return log.FileLogObserver(logFile).emit
def stop(self):
"""
Remove all log observers previously set up by L{AppLogger.start}.
"""
log.msg("Server Shut Down.")
if self._observer is not None:
log.removeObserver(self._observer)
self._observer = None
def fixPdb():
def do_stop(self, arg):
self.clear_all_breaks()
self.set_continue()
from twisted.internet import reactor
reactor.callLater(0, reactor.stop)
return 1
def help_stop(self):
print("stop - Continue execution, then cleanly shutdown the twisted "
"reactor.")
def set_quit(self):
os._exit(0)
pdb.Pdb.set_quit = set_quit
pdb.Pdb.do_stop = do_stop
pdb.Pdb.help_stop = help_stop
def runReactorWithLogging(config, oldstdout, oldstderr, profiler=None,
reactor=None):
"""
Start the reactor, using profiling if specified by the configuration, and
log any error happening in the process.
@param config: configuration of the twistd application.
@type config: L{ServerOptions}
@param oldstdout: initial value of C{sys.stdout}.
@type oldstdout: C{file}
@param oldstderr: initial value of C{sys.stderr}.
@type oldstderr: C{file}
@param profiler: object used to run the reactor with profiling.
@type profiler: L{AppProfiler}
@param reactor: The reactor to use. If C{None}, the global reactor will
be used.
"""
if reactor is None:
from twisted.internet import reactor
try:
if config['profile']:
if profiler is not None:
profiler.run(reactor)
elif config['debug']:
sys.stdout = oldstdout
sys.stderr = oldstderr
if runtime.platformType == 'posix':
signal.signal(signal.SIGUSR2, lambda *args: pdb.set_trace())
signal.signal(signal.SIGINT, lambda *args: pdb.set_trace())
fixPdb()
pdb.runcall(reactor.run)
else:
reactor.run()
except:
if config['nodaemon']:
file = oldstdout
else:
file = open("TWISTD-CRASH.log", "a")
traceback.print_exc(file=file)
file.flush()
def getPassphrase(needed):
if needed:
return getpass.getpass('Passphrase: ')
else:
return None
def getSavePassphrase(needed):
if needed:
return util.getPassword("Encryption passphrase: ")
else:
return None
class ApplicationRunner(object):
"""
An object which helps running an application based on a config object.
Subclass me and implement preApplication and postApplication
methods. postApplication generally will want to run the reactor
after starting the application.
@ivar config: The config object, which provides a dict-like interface.
@ivar application: Available in postApplication, but not
preApplication. This is the application object.
@ivar profilerFactory: Factory for creating a profiler object, able to
profile the application if options are set accordingly.
@ivar profiler: Instance provided by C{profilerFactory}.
@ivar loggerFactory: Factory for creating object responsible for logging.
@ivar logger: Instance provided by C{loggerFactory}.
"""
profilerFactory = AppProfiler
loggerFactory = AppLogger
def __init__(self, config):
self.config = config
self.profiler = self.profilerFactory(config)
self.logger = self.loggerFactory(config)
def run(self):
"""
Run the application.
"""
self.preApplication()
self.application = self.createOrGetApplication()
self.logger.start(self.application)
self.postApplication()
self.logger.stop()
def startReactor(self, reactor, oldstdout, oldstderr):
"""
Run the reactor with the given configuration. Subclasses should
probably call this from C{postApplication}.
@see: L{runReactorWithLogging}
"""
runReactorWithLogging(
self.config, oldstdout, oldstderr, self.profiler, reactor)
def preApplication(self):
"""
Override in subclass.
This should set up any state necessary before loading and
running the Application.
"""
raise NotImplementedError()
def postApplication(self):
"""
Override in subclass.
This will be called after the application has been loaded (so
the C{application} attribute will be set). Generally this
should start the application and run the reactor.
"""
raise NotImplementedError()
def createOrGetApplication(self):
"""
Create or load an Application based on the parameters found in the
given L{ServerOptions} instance.
If a subcommand was used, the L{service.IServiceMaker} that it
represents will be used to construct a service to be added to
a newly-created Application.
Otherwise, an application will be loaded based on parameters in
the config.
"""
if self.config.subCommand:
# If a subcommand was given, it's our responsibility to create
# the application, instead of load it from a file.
# loadedPlugins is set up by the ServerOptions.subCommands
# property, which is iterated somewhere in the bowels of
# usage.Options.
plg = self.config.loadedPlugins[self.config.subCommand]
ser = plg.makeService(self.config.subOptions)
application = service.Application(plg.tapname)
ser.setServiceParent(application)
else:
passphrase = getPassphrase(self.config['encrypted'])
application = getApplication(self.config, passphrase)
return application
def getApplication(config, passphrase):
s = [(config[t], t)
for t in ['python', 'source', 'file'] if config[t]][0]
filename, style = s[0], {'file': 'pickle'}.get(s[1], s[1])
try:
log.msg("Loading %s..." % filename)
application = service.loadApplication(filename, style, passphrase)
log.msg("Loaded.")
except Exception as e:
s = "Failed to load application: %s" % e
if isinstance(e, KeyError) and e.args[0] == "application":
s += """
Could not find 'application' in the file. To use 'twistd -y', your .tac
file must create a suitable object (e.g., by calling service.Application())
and store it in a variable named 'application'. twistd loads your .tac file
and scans the global variables for one of this name.
Please read the 'Using Application' HOWTO for details.
"""
traceback.print_exc(file=log.logfile)
log.msg(s)
log.deferr()
sys.exit('\n' + s + '\n')
return application
def _reactorAction():
return usage.CompleteList([r.shortName for r in
reactors.getReactorTypes()])
class ReactorSelectionMixin:
"""
Provides options for selecting a reactor to install.
If a reactor is installed, the short name which was used to locate it is
saved as the value for the C{"reactor"} key.
"""
compData = usage.Completions(
optActions={"reactor": _reactorAction})
messageOutput = sys.stdout
_getReactorTypes = staticmethod(reactors.getReactorTypes)
def opt_help_reactors(self):
"""
Display a list of possibly available reactor names.
"""
rcts = sorted(self._getReactorTypes(), key=attrgetter('shortName'))
for r in rcts:
self.messageOutput.write(' %-4s\t%s\n' %
(r.shortName, r.description))
raise SystemExit(0)
def opt_reactor(self, shortName):
"""
Which reactor to use (see --help-reactors for a list of possibilities)
"""
# Actually actually actually install the reactor right at this very
# moment, before any other code (for example, a sub-command plugin)
# runs and accidentally imports and installs the default reactor.
#
# This could probably be improved somehow.
try:
installReactor(shortName)
except NoSuchReactor:
msg = ("The specified reactor does not exist: '%s'.\n"
"See the list of available reactors with "
"--help-reactors" % (shortName,))
raise usage.UsageError(msg)
except Exception as e:
msg = ("The specified reactor cannot be used, failed with error: "
"%s.\nSee the list of available reactors with "
"--help-reactors" % (e,))
raise usage.UsageError(msg)
else:
self["reactor"] = shortName
opt_r = opt_reactor
class ServerOptions(usage.Options, ReactorSelectionMixin):
longdesc = ("twistd reads a twisted.application.service.Application out "
"of a file and runs it.")
optFlags = [['savestats', None,
"save the Stats object rather than the text output of "
"the profiler."],
['no_save', 'o', "do not save state on shutdown"],
['encrypted', 'e',
"The specified tap/aos file is encrypted."]]
optParameters = [['logfile', 'l', None,
"log to a specified file, - for stdout"],
['logger', None, None,
"A fully-qualified name to a log observer factory to "
"use for the initial log observer. Takes precedence "
"over --logfile and --syslog (when available)."],
['profile', 'p', None,
"Run in profile mode, dumping results to specified "
"file."],
['profiler', None, "hotshot",
"Name of the profiler to use (%s)." %
", ".join(AppProfiler.profilers)],
['file', 'f', 'twistd.tap',
"read the given .tap file"],
['python', 'y', None,
"read an application from within a Python file "
"(implies -o)"],
['source', 's', None,
"Read an application from a .tas file (AOT format)."],
['rundir', 'd', '.',
'Change to a supplied directory before running']]
compData = usage.Completions(
mutuallyExclusive=[("file", "python", "source")],
optActions={"file": usage.CompleteFiles("*.tap"),
"python": usage.CompleteFiles("*.(tac|py)"),
"source": usage.CompleteFiles("*.tas"),
"rundir": usage.CompleteDirs()}
)
_getPlugins = staticmethod(plugin.getPlugins)
def __init__(self, *a, **kw):
self['debug'] = False
usage.Options.__init__(self, *a, **kw)
def opt_debug(self):
"""
Run the application in the Python Debugger (implies nodaemon),
sending SIGUSR2 will drop into debugger
"""
defer.setDebugging(True)
failure.startDebugMode()
self['debug'] = True
opt_b = opt_debug
def opt_spew(self):
"""
Print an insanely verbose log of everything that happens.
Useful when debugging freezes or locks in complex code."""
sys.settrace(util.spewer)
try:
import threading
except ImportError:
return
threading.settrace(util.spewer)
def parseOptions(self, options=None):
if options is None:
options = sys.argv[1:] or ["--help"]
usage.Options.parseOptions(self, options)
def postOptions(self):
if self.subCommand or self['python']:
self['no_save'] = True
if self['logger'] is not None:
try:
self['logger'] = namedAny(self['logger'])
except Exception as e:
raise usage.UsageError("Logger '%s' could not be imported: %s"
% (self['logger'], e))
def subCommands(self):
plugins = self._getPlugins(service.IServiceMaker)
self.loadedPlugins = {}
for plug in sorted(plugins, key=attrgetter('tapname')):
self.loadedPlugins[plug.tapname] = plug
yield (plug.tapname,
None,
# Avoid resolving the options attribute right away, in case
# it's a property with a non-trivial getter (eg, one which
# imports modules).
lambda plug=plug: plug.options(),
plug.description)
subCommands = property(subCommands)
def run(runApp, ServerOptions):
config = ServerOptions()
try:
config.parseOptions()
except usage.error as ue:
print(config)
print("%s: %s" % (sys.argv[0], ue))
else:
runApp(config)
def convertStyle(filein, typein, passphrase, fileout, typeout, encrypt):
application = service.loadApplication(filein, typein, passphrase)
sob.IPersistable(application).setStyle(typeout)
passphrase = getSavePassphrase(encrypt)
if passphrase:
fileout = None
sob.IPersistable(application).save(filename=fileout, passphrase=passphrase)
def startApplication(application, save):
from twisted.internet import reactor
service.IService(application).startService()
if save:
p = sob.IPersistable(application)
reactor.addSystemEventTrigger('after', 'shutdown', p.save, 'shutdown')
reactor.addSystemEventTrigger('before', 'shutdown',
service.IService(application).stopService)

View File

@ -0,0 +1,396 @@
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Reactor-based Services
Here are services to run clients, servers and periodic services using
the reactor.
If you want to run a server service, L{StreamServerEndpointService} defines a
service that can wrap an arbitrary L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>}
as an L{IService}. See also L{twisted.application.strports.service} for
constructing one of these directly from a descriptive string.
Additionally, this module (dynamically) defines various Service subclasses that
let you represent clients and servers in a Service hierarchy. Endpoints APIs
should be preferred for stream server services, but since those APIs do not yet
exist for clients or datagram services, many of these are still useful.
They are as follows::
TCPServer, TCPClient,
UNIXServer, UNIXClient,
SSLServer, SSLClient,
UDPServer,
UNIXDatagramServer, UNIXDatagramClient,
MulticastServer
These classes take arbitrary arguments in their constructors and pass
them straight on to their respective reactor.listenXXX or
reactor.connectXXX calls.
For example, the following service starts a web server on port 8080:
C{TCPServer(8080, server.Site(r))}. See the documentation for the
reactor.listen/connect* methods for more information.
"""
from twisted.python import log
from twisted.application import service
from twisted.internet import task
from twisted.internet.defer import CancelledError
def _maybeGlobalReactor(maybeReactor):
"""
@return: the argument, or the global reactor if the argument is C{None}.
"""
if maybeReactor is None:
from twisted.internet import reactor
return reactor
else:
return maybeReactor
class _VolatileDataService(service.Service):
volatile = []
def __getstate__(self):
d = service.Service.__getstate__(self)
for attr in self.volatile:
if attr in d:
del d[attr]
return d
class _AbstractServer(_VolatileDataService):
"""
@cvar volatile: list of attribute to remove from pickling.
@type volatile: C{list}
@ivar method: the type of method to call on the reactor, one of B{TCP},
B{UDP}, B{SSL} or B{UNIX}.
@type method: C{str}
@ivar reactor: the current running reactor.
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
C{IReactorSSL} or C{IReactorUnix}.
@ivar _port: instance of port set when the service is started.
@type _port: a provider of L{twisted.internet.interfaces.IListeningPort}.
"""
volatile = ['_port']
method = None
reactor = None
_port = None
def __init__(self, *args, **kwargs):
self.args = args
if 'reactor' in kwargs:
self.reactor = kwargs.pop("reactor")
self.kwargs = kwargs
def privilegedStartService(self):
service.Service.privilegedStartService(self)
self._port = self._getPort()
def startService(self):
service.Service.startService(self)
if self._port is None:
self._port = self._getPort()
def stopService(self):
service.Service.stopService(self)
# TODO: if startup failed, should shutdown skip stopListening?
# _port won't exist
if self._port is not None:
d = self._port.stopListening()
del self._port
return d
def _getPort(self):
"""
Wrapper around the appropriate listen method of the reactor.
@return: the port object returned by the listen method.
@rtype: an object providing
L{twisted.internet.interfaces.IListeningPort}.
"""
return getattr(_maybeGlobalReactor(self.reactor),
'listen%s' % (self.method,))(*self.args, **self.kwargs)
class _AbstractClient(_VolatileDataService):
"""
@cvar volatile: list of attribute to remove from pickling.
@type volatile: C{list}
@ivar method: the type of method to call on the reactor, one of B{TCP},
B{UDP}, B{SSL} or B{UNIX}.
@type method: C{str}
@ivar reactor: the current running reactor.
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
C{IReactorSSL} or C{IReactorUnix}.
@ivar _connection: instance of connection set when the service is started.
@type _connection: a provider of L{twisted.internet.interfaces.IConnector}.
"""
volatile = ['_connection']
method = None
reactor = None
_connection = None
def __init__(self, *args, **kwargs):
self.args = args
if 'reactor' in kwargs:
self.reactor = kwargs.pop("reactor")
self.kwargs = kwargs
def startService(self):
service.Service.startService(self)
self._connection = self._getConnection()
def stopService(self):
service.Service.stopService(self)
if self._connection is not None:
self._connection.disconnect()
del self._connection
def _getConnection(self):
"""
Wrapper around the appropriate connect method of the reactor.
@return: the port object returned by the connect method.
@rtype: an object providing L{twisted.internet.interfaces.IConnector}.
"""
return getattr(_maybeGlobalReactor(self.reactor),
'connect%s' % (self.method,))(*self.args, **self.kwargs)
_doc={
'Client':
"""Connect to %(tran)s
Call reactor.connect%(tran)s when the service starts, with the
arguments given to the constructor.
""",
'Server':
"""Serve %(tran)s clients
Call reactor.listen%(tran)s when the service starts, with the
arguments given to the constructor. When the service stops,
stop listening. See twisted.internet.interfaces for documentation
on arguments to the reactor method.
""",
}
import types
for tran in 'TCP UNIX SSL UDP UNIXDatagram Multicast'.split():
for side in 'Server Client'.split():
if tran == "Multicast" and side == "Client":
continue
if tran == "UDP" and side == "Client":
continue
base = globals()['_Abstract'+side]
doc = _doc[side] % vars()
klass = types.ClassType(tran+side, (base,),
{'method': tran, '__doc__': doc})
globals()[tran+side] = klass
class TimerService(_VolatileDataService):
"""
Service to periodically call a function
Every C{step} seconds call the given function with the given arguments.
The service starts the calls when it starts, and cancels them
when it stops.
@ivar clock: Source of time. This defaults to L{None} which is
causes L{twisted.internet.reactor} to be used.
Feel free to set this to something else, but it probably ought to be
set *before* calling L{startService}.
@type clock: L{IReactorTime<twisted.internet.interfaces.IReactorTime>}
@ivar call: Function and arguments to call periodically.
@type call: L{tuple} of C{(callable, args, kwargs)}
"""
volatile = ['_loop', '_loopFinished']
def __init__(self, step, callable, *args, **kwargs):
"""
@param step: The number of seconds between calls.
@type step: L{float}
@param callable: Function to call
@type callable: L{callable}
@param args: Positional arguments to pass to function
@param kwargs: Keyword arguments to pass to function
"""
self.step = step
self.call = (callable, args, kwargs)
self.clock = None
def startService(self):
service.Service.startService(self)
callable, args, kwargs = self.call
# we have to make a new LoopingCall each time we're started, because
# an active LoopingCall remains active when serialized. If
# LoopingCall were a _VolatileDataService, we wouldn't need to do
# this.
self._loop = task.LoopingCall(callable, *args, **kwargs)
self._loop.clock = _maybeGlobalReactor(self.clock)
self._loopFinished = self._loop.start(self.step, now=True)
self._loopFinished.addErrback(self._failed)
def _failed(self, why):
# make a note that the LoopingCall is no longer looping, so we don't
# try to shut it down a second time in stopService. I think this
# should be in LoopingCall. -warner
self._loop.running = False
log.err(why)
def stopService(self):
"""
Stop the service.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is fired when the
currently running call (if any) is finished.
"""
if self._loop.running:
self._loop.stop()
self._loopFinished.addCallback(lambda _:
service.Service.stopService(self))
return self._loopFinished
class CooperatorService(service.Service):
"""
Simple L{service.IService} which starts and stops a L{twisted.internet.task.Cooperator}.
"""
def __init__(self):
self.coop = task.Cooperator(started=False)
def coiterate(self, iterator):
return self.coop.coiterate(iterator)
def startService(self):
self.coop.start()
def stopService(self):
self.coop.stop()
class StreamServerEndpointService(service.Service, object):
"""
A L{StreamServerEndpointService} is an L{IService} which runs a server on a
listening port described by an L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>}.
@ivar factory: A server factory which will be used to listen on the
endpoint.
@ivar endpoint: An L{IStreamServerEndpoint
<twisted.internet.interfaces.IStreamServerEndpoint>} provider
which will be used to listen when the service starts.
@ivar _waitingForPort: a Deferred, if C{listen} has yet been invoked on the
endpoint, otherwise None.
@ivar _raiseSynchronously: Defines error-handling behavior for the case
where C{listen(...)} raises an exception before C{startService} or
C{privilegedStartService} have completed.
@type _raiseSynchronously: C{bool}
@since: 10.2
"""
_raiseSynchronously = None
def __init__(self, endpoint, factory):
self.endpoint = endpoint
self.factory = factory
self._waitingForPort = None
def privilegedStartService(self):
"""
Start listening on the endpoint.
"""
service.Service.privilegedStartService(self)
self._waitingForPort = self.endpoint.listen(self.factory)
raisedNow = []
def handleIt(err):
if self._raiseSynchronously:
raisedNow.append(err)
elif not err.check(CancelledError):
log.err(err)
self._waitingForPort.addErrback(handleIt)
if raisedNow:
raisedNow[0].raiseException()
def startService(self):
"""
Start listening on the endpoint, unless L{privilegedStartService} got
around to it already.
"""
service.Service.startService(self)
if self._waitingForPort is None:
self.privilegedStartService()
def stopService(self):
"""
Stop listening on the port if it is already listening, otherwise,
cancel the attempt to listen.
@return: a L{Deferred<twisted.internet.defer.Deferred>} which fires
with C{None} when the port has stopped listening.
"""
self._waitingForPort.cancel()
def stopIt(port):
if port is not None:
return port.stopListening()
d = self._waitingForPort.addCallback(stopIt)
def stop(passthrough):
self.running = False
return passthrough
d.addBoth(stop)
return d
__all__ = (['TimerService', 'CooperatorService', 'MulticastServer',
'StreamServerEndpointService', 'UDPServer'] +
[tran+side
for tran in 'TCP UNIX SSL UNIXDatagram'.split()
for side in 'Server Client'.split()])

View File

@ -0,0 +1,84 @@
# -*- test-case-name: twisted.test.test_application -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Plugin-based system for enumerating available reactors and installing one of
them.
"""
from zope.interface import Interface, Attribute, implementer
from twisted.plugin import IPlugin, getPlugins
from twisted.python.reflect import namedAny
class IReactorInstaller(Interface):
"""
Definition of a reactor which can probably be installed.
"""
shortName = Attribute("""
A brief string giving the user-facing name of this reactor.
""")
description = Attribute("""
A longer string giving a user-facing description of this reactor.
""")
def install():
"""
Install this reactor.
"""
# TODO - A method which provides a best-guess as to whether this reactor
# can actually be used in the execution environment.
class NoSuchReactor(KeyError):
"""
Raised when an attempt is made to install a reactor which cannot be found.
"""
@implementer(IPlugin, IReactorInstaller)
class Reactor(object):
"""
@ivar moduleName: The fully-qualified Python name of the module of which
the install callable is an attribute.
"""
def __init__(self, shortName, moduleName, description):
self.shortName = shortName
self.moduleName = moduleName
self.description = description
def install(self):
namedAny(self.moduleName).install()
def getReactorTypes():
"""
Return an iterator of L{IReactorInstaller} plugins.
"""
return getPlugins(IReactorInstaller)
def installReactor(shortName):
"""
Install the reactor with the given C{shortName} attribute.
@raise NoSuchReactor: If no reactor is found with a matching C{shortName}.
@raise: anything that the specified reactor can raise when installed.
"""
for installer in getReactorTypes():
if installer.shortName == shortName:
installer.install()
from twisted.internet import reactor
return reactor
raise NoSuchReactor(shortName)

View File

@ -0,0 +1,411 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Service architecture for Twisted.
Services are arranged in a hierarchy. At the leafs of the hierarchy,
the services which actually interact with the outside world are started.
Services can be named or anonymous -- usually, they will be named if
there is need to access them through the hierarchy (from a parent or
a sibling).
Maintainer: Moshe Zadka
"""
from zope.interface import implementer, Interface, Attribute
from twisted.python.reflect import namedAny
from twisted.python import components
from twisted.internet import defer
from twisted.persisted import sob
from twisted.plugin import IPlugin
class IServiceMaker(Interface):
"""
An object which can be used to construct services in a flexible
way.
This interface should most often be implemented along with
L{twisted.plugin.IPlugin}, and will most often be used by the
'twistd' command.
"""
tapname = Attribute(
"A short string naming this Twisted plugin, for example 'web' or "
"'pencil'. This name will be used as the subcommand of 'twistd'.")
description = Attribute(
"A brief summary of the features provided by this "
"Twisted application plugin.")
options = Attribute(
"A C{twisted.python.usage.Options} subclass defining the "
"configuration options for this application.")
def makeService(options):
"""
Create and return an object providing
L{twisted.application.service.IService}.
@param options: A mapping (typically a C{dict} or
L{twisted.python.usage.Options} instance) of configuration
options to desired configuration values.
"""
@implementer(IPlugin, IServiceMaker)
class ServiceMaker(object):
"""
Utility class to simplify the definition of L{IServiceMaker} plugins.
"""
def __init__(self, name, module, description, tapname):
self.name = name
self.module = module
self.description = description
self.tapname = tapname
def options():
def get(self):
return namedAny(self.module).Options
return get,
options = property(*options())
def makeService():
def get(self):
return namedAny(self.module).makeService
return get,
makeService = property(*makeService())
class IService(Interface):
"""
A service.
Run start-up and shut-down code at the appropriate times.
@type name: C{string}
@ivar name: The name of the service (or None)
@type running: C{boolean}
@ivar running: Whether the service is running.
"""
def setName(name):
"""
Set the name of the service.
@type name: C{str}
@raise RuntimeError: Raised if the service already has a parent.
"""
def setServiceParent(parent):
"""
Set the parent of the service. This method is responsible for setting
the C{parent} attribute on this service (the child service).
@type parent: L{IServiceCollection}
@raise RuntimeError: Raised if the service already has a parent
or if the service has a name and the parent already has a child
by that name.
"""
def disownServiceParent():
"""
Use this API to remove an L{IService} from an L{IServiceCollection}.
This method is used symmetrically with L{setServiceParent} in that it
sets the C{parent} attribute on the child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, C{None}).
"""
def startService():
"""
Start the service.
"""
def stopService():
"""
Stop the service.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, C{None}).
"""
def privilegedStartService():
"""
Do preparation work for starting the service.
Here things which should be done before changing directory,
root or shedding privileges are done.
"""
@implementer(IService)
class Service:
"""
Base class for services.
Most services should inherit from this class. It handles the
book-keeping responsibilities of starting and stopping, as well
as not serializing this book-keeping information.
"""
running = 0
name = None
parent = None
def __getstate__(self):
dict = self.__dict__.copy()
if "running" in dict:
del dict['running']
return dict
def setName(self, name):
if self.parent is not None:
raise RuntimeError("cannot change name when parent exists")
self.name = name
def setServiceParent(self, parent):
if self.parent is not None:
self.disownServiceParent()
parent = IServiceCollection(parent, parent)
self.parent = parent
self.parent.addService(self)
def disownServiceParent(self):
d = self.parent.removeService(self)
self.parent = None
return d
def privilegedStartService(self):
pass
def startService(self):
self.running = 1
def stopService(self):
self.running = 0
class IServiceCollection(Interface):
"""
Collection of services.
Contain several services, and manage their start-up/shut-down.
Services can be accessed by name if they have a name, and it
is always possible to iterate over them.
"""
def getServiceNamed(name):
"""
Get the child service with a given name.
@type name: C{str}
@rtype: L{IService}
@raise KeyError: Raised if the service has no child with the
given name.
"""
def __iter__():
"""
Get an iterator over all child services.
"""
def addService(service):
"""
Add a child service.
Only implementations of L{IService.setServiceParent} should use this
method.
@type service: L{IService}
@raise RuntimeError: Raised if the service has a child with
the given name.
"""
def removeService(service):
"""
Remove a child service.
Only implementations of L{IService.disownServiceParent} should
use this method.
@type service: L{IService}
@raise ValueError: Raised if the given service is not a child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, C{None}).
"""
@implementer(IServiceCollection)
class MultiService(Service):
"""
Straightforward Service Container.
Hold a collection of services, and manage them in a simplistic
way. No service will wait for another, but this object itself
will not finish shutting down until all of its child services
will finish.
"""
def __init__(self):
self.services = []
self.namedServices = {}
self.parent = None
def privilegedStartService(self):
Service.privilegedStartService(self)
for service in self:
service.privilegedStartService()
def startService(self):
Service.startService(self)
for service in self:
service.startService()
def stopService(self):
Service.stopService(self)
l = []
services = list(self)
services.reverse()
for service in services:
l.append(defer.maybeDeferred(service.stopService))
return defer.DeferredList(l)
def getServiceNamed(self, name):
return self.namedServices[name]
def __iter__(self):
return iter(self.services)
def addService(self, service):
if service.name is not None:
if service.name in self.namedServices:
raise RuntimeError("cannot have two services with same name"
" '%s'" % service.name)
self.namedServices[service.name] = service
self.services.append(service)
if self.running:
# It may be too late for that, but we will do our best
service.privilegedStartService()
service.startService()
def removeService(self, service):
if service.name:
del self.namedServices[service.name]
self.services.remove(service)
if self.running:
# Returning this so as not to lose information from the
# MultiService.stopService deferred.
return service.stopService()
else:
return None
class IProcess(Interface):
"""
Process running parameters.
Represents parameters for how processes should be run.
"""
processName = Attribute(
"""
A C{str} giving the name the process should have in ps (or C{None}
to leave the name alone).
""")
uid = Attribute(
"""
An C{int} giving the user id as which the process should run (or
C{None} to leave the UID alone).
""")
gid = Attribute(
"""
An C{int} giving the group id as which the process should run (or
C{None} to leave the GID alone).
""")
@implementer(IProcess)
class Process:
"""
Process running parameters.
Sets up uid/gid in the constructor, and has a default
of C{None} as C{processName}.
"""
processName = None
def __init__(self, uid=None, gid=None):
"""
Set uid and gid.
@param uid: The user ID as whom to execute the process. If
this is C{None}, no attempt will be made to change the UID.
@param gid: The group ID as whom to execute the process. If
this is C{None}, no attempt will be made to change the GID.
"""
self.uid = uid
self.gid = gid
def Application(name, uid=None, gid=None):
"""
Return a compound class.
Return an object supporting the L{IService}, L{IServiceCollection},
L{IProcess} and L{sob.IPersistable} interfaces, with the given
parameters. Always access the return value by explicit casting to
one of the interfaces.
"""
ret = components.Componentized()
for comp in (MultiService(), sob.Persistent(ret, name), Process(uid, gid)):
ret.addComponent(comp, ignoreClass=1)
IService(ret).setName(name)
return ret
def loadApplication(filename, kind, passphrase=None):
"""
Load Application from a given file.
The serialization format it was saved in should be given as
C{kind}, and is one of C{pickle}, C{source}, C{xml} or C{python}. If
C{passphrase} is given, the application was encrypted with the
given passphrase.
@type filename: C{str}
@type kind: C{str}
@type passphrase: C{str}
"""
if kind == 'python':
application = sob.loadValueFromFile(filename, 'application', passphrase)
else:
application = sob.load(filename, kind, passphrase)
return application
__all__ = ['IServiceMaker', 'IService', 'Service',
'IServiceCollection', 'MultiService',
'IProcess', 'Process', 'Application', 'loadApplication']

View File

@ -0,0 +1,103 @@
# -*- test-case-name: twisted.test.test_strports -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Construct listening port services from a simple string description.
@see: L{twisted.internet.endpoints.serverFromString}
@see: L{twisted.internet.endpoints.clientFromString}
"""
import warnings
from twisted.internet import endpoints
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.versions import Version
from twisted.application.internet import StreamServerEndpointService
def parse(description, factory, default='tcp'):
"""
This function is deprecated as of Twisted 10.2.
@see: L{twisted.internet.endpoints.server}
"""
return endpoints._parseServer(description, factory, default)
deprecatedModuleAttribute(
Version("Twisted", 10, 2, 0),
"in favor of twisted.internet.endpoints.serverFromString",
__name__, "parse")
_DEFAULT = object()
def service(description, factory, default=_DEFAULT, reactor=None):
"""
Return the service corresponding to a description.
@param description: The description of the listening port, in the syntax
described by L{twisted.internet.endpoints.server}.
@type description: C{str}
@param factory: The protocol factory which will build protocols for
connections to this service.
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@type default: C{str} or C{None}
@param default: Do not use this parameter. It has been deprecated since
Twisted 10.2.0.
@rtype: C{twisted.application.service.IService}
@return: the service corresponding to a description of a reliable
stream server.
@see: L{twisted.internet.endpoints.serverFromString}
"""
if reactor is None:
from twisted.internet import reactor
if default is _DEFAULT:
default = None
else:
message = "The 'default' parameter was deprecated in Twisted 10.2.0."
if default is not None:
message += (
" Use qualified endpoint descriptions; for example, "
"'tcp:%s'." % (description,))
warnings.warn(
message=message, category=DeprecationWarning, stacklevel=2)
svc = StreamServerEndpointService(
endpoints._serverFromStringLegacy(reactor, description, default),
factory)
svc._raiseSynchronously = True
return svc
def listen(description, factory, default=None):
"""Listen on a port corresponding to a description
@type description: C{str}
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@type default: C{str} or C{None}
@rtype: C{twisted.internet.interfaces.IListeningPort}
@return: the port corresponding to a description of a reliable
virtual circuit server.
See the documentation of the C{parse} function for description
of the semantics of the arguments.
"""
from twisted.internet import reactor
name, args, kw = parse(description, factory, default)
return getattr(reactor, 'listen'+name)(*args, **kw)
__all__ = ['parse', 'service', 'listen']

View File

@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.application}.
"""

View File

@ -0,0 +1,403 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for (new code in) L{twisted.application.internet}.
"""
import pickle
from zope.interface import implements
from zope.interface.verify import verifyClass
from twisted.internet.protocol import Factory
from twisted.trial.unittest import TestCase
from twisted.application import internet
from twisted.application.internet import (
StreamServerEndpointService, TimerService)
from twisted.internet.interfaces import IStreamServerEndpoint, IListeningPort
from twisted.internet.defer import Deferred, CancelledError
from twisted.internet import task
from twisted.python.failure import Failure
def fakeTargetFunction():
"""
A fake target function for testing TimerService which does nothing.
"""
pass
class FakeServer(object):
"""
In-memory implementation of L{IStreamServerEndpoint}.
@ivar result: The L{Deferred} resulting from the call to C{listen}, after
C{listen} has been called.
@ivar factory: The factory passed to C{listen}.
@ivar cancelException: The exception to errback C{self.result} when it is
cancelled.
@ivar port: The L{IListeningPort} which C{listen}'s L{Deferred} will fire
with.
@ivar listenAttempts: The number of times C{listen} has been invoked.
@ivar failImmediately: If set, the exception to fail the L{Deferred}
returned from C{listen} before it is returned.
"""
implements(IStreamServerEndpoint)
result = None
factory = None
failImmediately = None
cancelException = CancelledError()
listenAttempts = 0
def __init__(self):
self.port = FakePort()
def listen(self, factory):
"""
Return a Deferred and store it for future use. (Implementation of
L{IStreamServerEndpoint}).
"""
self.listenAttempts += 1
self.factory = factory
self.result = Deferred(
canceller=lambda d: d.errback(self.cancelException))
if self.failImmediately is not None:
self.result.errback(self.failImmediately)
return self.result
def startedListening(self):
"""
Test code should invoke this method after causing C{listen} to be
invoked in order to fire the L{Deferred} previously returned from
C{listen}.
"""
self.result.callback(self.port)
def stoppedListening(self):
"""
Test code should invoke this method after causing C{stopListening} to
be invoked on the port fired from the L{Deferred} returned from
C{listen} in order to cause the L{Deferred} returned from
C{stopListening} to fire.
"""
self.port.deferred.callback(None)
verifyClass(IStreamServerEndpoint, FakeServer)
class FakePort(object):
"""
Fake L{IListeningPort} implementation.
@ivar deferred: The L{Deferred} returned by C{stopListening}.
"""
implements(IListeningPort)
deferred = None
def stopListening(self):
self.deferred = Deferred()
return self.deferred
verifyClass(IStreamServerEndpoint, FakeServer)
class EndpointServiceTests(TestCase):
"""
Tests for L{twisted.application.internet}.
"""
def setUp(self):
"""
Construct a stub server, a stub factory, and a
L{StreamServerEndpointService} to test.
"""
self.fakeServer = FakeServer()
self.factory = Factory()
self.svc = StreamServerEndpointService(self.fakeServer, self.factory)
def test_privilegedStartService(self):
"""
L{StreamServerEndpointService.privilegedStartService} calls its
endpoint's C{listen} method with its factory.
"""
self.svc.privilegedStartService()
self.assertIdentical(self.factory, self.fakeServer.factory)
def test_synchronousRaiseRaisesSynchronously(self, thunk=None):
"""
L{StreamServerEndpointService.startService} should raise synchronously
if the L{Deferred} returned by its wrapped
L{IStreamServerEndpoint.listen} has already fired with an errback and
the L{StreamServerEndpointService}'s C{_raiseSynchronously} flag has
been set. This feature is necessary to preserve compatibility with old
behavior of L{twisted.internet.strports.service}, which is to return a
service which synchronously raises an exception from C{startService}
(so that, among other things, twistd will not start running). However,
since L{IStreamServerEndpoint.listen} may fail asynchronously, it is
a bad idea to rely on this behavior.
"""
self.fakeServer.failImmediately = ZeroDivisionError()
self.svc._raiseSynchronously = True
self.assertRaises(ZeroDivisionError, thunk or self.svc.startService)
def test_synchronousRaisePrivileged(self):
"""
L{StreamServerEndpointService.privilegedStartService} should behave the
same as C{startService} with respect to
L{EndpointServiceTests.test_synchronousRaiseRaisesSynchronously}.
"""
self.test_synchronousRaiseRaisesSynchronously(
self.svc.privilegedStartService)
def test_failReportsError(self):
"""
L{StreamServerEndpointService.startService} and
L{StreamServerEndpointService.privilegedStartService} should both log
an exception when the L{Deferred} returned from their wrapped
L{IStreamServerEndpoint.listen} fails.
"""
self.svc.startService()
self.fakeServer.result.errback(ZeroDivisionError())
logged = self.flushLoggedErrors(ZeroDivisionError)
self.assertEqual(len(logged), 1)
def test_synchronousFailReportsError(self):
"""
Without the C{_raiseSynchronously} compatibility flag, failing
immediately has the same behavior as failing later; it logs the error.
"""
self.fakeServer.failImmediately = ZeroDivisionError()
self.svc.startService()
logged = self.flushLoggedErrors(ZeroDivisionError)
self.assertEqual(len(logged), 1)
def test_startServiceUnstarted(self):
"""
L{StreamServerEndpointService.startService} sets the C{running} flag,
and calls its endpoint's C{listen} method with its factory, if it
has not yet been started.
"""
self.svc.startService()
self.assertIdentical(self.factory, self.fakeServer.factory)
self.assertEqual(self.svc.running, True)
def test_startServiceStarted(self):
"""
L{StreamServerEndpointService.startService} sets the C{running} flag,
but nothing else, if the service has already been started.
"""
self.test_privilegedStartService()
self.svc.startService()
self.assertEqual(self.fakeServer.listenAttempts, 1)
self.assertEqual(self.svc.running, True)
def test_stopService(self):
"""
L{StreamServerEndpointService.stopService} calls C{stopListening} on
the L{IListeningPort} returned from its endpoint, returns the
C{Deferred} from stopService, and sets C{running} to C{False}.
"""
self.svc.privilegedStartService()
self.fakeServer.startedListening()
# Ensure running gets set to true
self.svc.startService()
result = self.svc.stopService()
l = []
result.addCallback(l.append)
self.assertEqual(len(l), 0)
self.fakeServer.stoppedListening()
self.assertEqual(len(l), 1)
self.assertFalse(self.svc.running)
def test_stopServiceBeforeStartFinished(self):
"""
L{StreamServerEndpointService.stopService} cancels the L{Deferred}
returned by C{listen} if it has not yet fired. No error will be logged
about the cancellation of the listen attempt.
"""
self.svc.privilegedStartService()
result = self.svc.stopService()
l = []
result.addBoth(l.append)
self.assertEqual(l, [None])
self.assertEqual(self.flushLoggedErrors(CancelledError), [])
def test_stopServiceCancelStartError(self):
"""
L{StreamServerEndpointService.stopService} cancels the L{Deferred}
returned by C{listen} if it has not fired yet. An error will be logged
if the resulting exception is not L{CancelledError}.
"""
self.fakeServer.cancelException = ZeroDivisionError()
self.svc.privilegedStartService()
result = self.svc.stopService()
l = []
result.addCallback(l.append)
self.assertEqual(l, [None])
stoppingErrors = self.flushLoggedErrors(ZeroDivisionError)
self.assertEqual(len(stoppingErrors), 1)
class TimerServiceTests(TestCase):
"""
Tests for L{twisted.application.internet.TimerService}.
@type timer: L{TimerService}
@ivar timer: service to test
@type clock: L{task.Clock}
@ivar clock: source of time
@type deferred: L{Deferred}
@ivar deferred: deferred returned by L{TimerServiceTests.call}.
"""
def setUp(self):
self.timer = TimerService(2, self.call)
self.clock = self.timer.clock = task.Clock()
self.deferred = Deferred()
def call(self):
"""
Function called by L{TimerService} being tested.
@returns: C{self.deferred}
@rtype: L{Deferred}
"""
return self.deferred
def test_startService(self):
"""
When L{TimerService.startService} is called, it marks itself
as running, creates a L{task.LoopingCall} and starts it.
"""
self.timer.startService()
self.assertTrue(self.timer.running, "Service is started")
self.assertIsInstance(self.timer._loop, task.LoopingCall)
self.assertIdentical(self.clock, self.timer._loop.clock)
self.assertTrue(self.timer._loop.running, "LoopingCall is started")
def test_startServiceRunsCallImmediately(self):
"""
When L{TimerService.startService} is called, it calls the function
immediately.
"""
result = []
self.timer.call = (result.append, (None,), {})
self.timer.startService()
self.assertEqual([None], result)
def test_startServiceUsesGlobalReactor(self):
"""
L{TimerService.startService} uses L{internet._maybeGlobalReactor} to
choose the reactor to pass to L{task.LoopingCall}
uses the global reactor.
"""
otherClock = task.Clock()
def getOtherClock(maybeReactor):
return otherClock
self.patch(internet, "_maybeGlobalReactor", getOtherClock)
self.timer.startService()
self.assertIdentical(otherClock, self.timer._loop.clock)
def test_stopServiceWaits(self):
"""
When L{TimerService.stopService} is called while a call is in progress.
the L{Deferred} returned doesn't fire until after the call finishes.
"""
self.timer.startService()
d = self.timer.stopService()
self.assertNoResult(d)
self.assertEqual(True, self.timer.running)
self.deferred.callback(object())
self.assertIdentical(self.successResultOf(d), None)
def test_stopServiceImmediately(self):
"""
When L{TimerService.stopService} is called while a call isn't in progress.
the L{Deferred} returned has already been fired.
"""
self.timer.startService()
self.deferred.callback(object())
d = self.timer.stopService()
self.assertIdentical(self.successResultOf(d), None)
def test_failedCallLogsError(self):
"""
When function passed to L{TimerService} returns a deferred that errbacks,
the exception is logged, and L{TimerService.stopService} doesn't raise an error.
"""
self.timer.startService()
self.deferred.errback(Failure(ZeroDivisionError()))
errors = self.flushLoggedErrors(ZeroDivisionError)
self.assertEqual(1, len(errors))
d = self.timer.stopService()
self.assertIdentical(self.successResultOf(d), None)
def test_pickleTimerServiceNotPickleLoop(self):
"""
When pickling L{internet.TimerService}, it won't pickle
L{internet.TimerService._loop}.
"""
# We need a pickleable callable to test pickling TimerService. So we
# can't use self.timer
timer = TimerService(1, fakeTargetFunction)
timer.startService()
dumpedTimer = pickle.dumps(timer)
timer.stopService()
loadedTimer = pickle.loads(dumpedTimer)
nothing = object()
value = getattr(loadedTimer, "_loop", nothing)
self.assertIdentical(nothing, value)
def test_pickleTimerServiceNotPickleLoopFinished(self):
"""
When pickling L{internet.TimerService}, it won't pickle
L{internet.TimerService._loopFinished}.
"""
# We need a pickleable callable to test pickling TimerService. So we
# can't use self.timer
timer = TimerService(1, fakeTargetFunction)
timer.startService()
dumpedTimer = pickle.dumps(timer)
timer.stopService()
loadedTimer = pickle.loads(dumpedTimer)
nothing = object()
value = getattr(loadedTimer, "_loopFinished", nothing)
self.assertIdentical(nothing, value)

View File

@ -0,0 +1,10 @@
# -*- test-case-name: twisted.conch.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
"""
from twisted.conch._version import version
__version__ = version.short()

View File

@ -0,0 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# This is an auto-generated file. Do not edit it.
"""
Provides Twisted version information.
"""
from twisted.python import versions
version = versions.Version('twisted.conch', 15, 2, 1)

View File

@ -0,0 +1,39 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
from zope.interface import implementer
from twisted.conch.error import ConchError
from twisted.conch.interfaces import IConchUser
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
from twisted.python import log
@implementer(IConchUser)
class ConchUser:
def __init__(self):
self.channelLookup = {}
self.subsystemLookup = {}
def lookupChannel(self, channelType, windowSize, maxPacket, data):
klass = self.channelLookup.get(channelType, None)
if not klass:
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
else:
return klass(remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data, avatar=self)
def lookupSubsystem(self, subsystem, data):
log.msg(repr(self.subsystemLookup))
klass = self.subsystemLookup.get(subsystem, None)
if not klass:
return False
return klass(data, avatar=self)
def gotGlobalRequest(self, requestType, data):
# XXX should this use method dispatch?
requestType = requestType.replace('-', '_')
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)

View File

@ -0,0 +1,570 @@
# -*- test-case-name: twisted.conch.test.test_checkers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
"""
import base64, binascii, errno
try:
import pwd
except ImportError:
pwd = None
else:
import crypt
try:
# Python 2.5 got spwd to interface with shadow passwords
import spwd
except ImportError:
spwd = None
try:
import shadow
except ImportError:
shadow = None
else:
shadow = None
try:
from twisted.cred import pamauth
except ImportError:
pamauth = None
from zope.interface import providedBy, implementer, Interface
from twisted.conch import error
from twisted.conch.ssh import keys
from twisted.cred.checkers import ICredentialsChecker
from twisted.cred.credentials import IUsernamePassword, ISSHPrivateKey
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
from twisted.internet import defer
from twisted.python import failure, reflect, log
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.util import runAsEffectiveUser
from twisted.python.filepath import FilePath
from twisted.python.versions import Version
def verifyCryptedPassword(crypted, pw):
return crypt.crypt(pw, crypted) == crypted
def _pwdGetByName(username):
"""
Look up a user in the /etc/passwd database using the pwd module. If the
pwd module is not available, return None.
@param username: the username of the user to return the passwd database
information for.
"""
if pwd is None:
return None
return pwd.getpwnam(username)
def _shadowGetByName(username):
"""
Look up a user in the /etc/shadow database using the spwd or shadow
modules. If neither module is available, return None.
@param username: the username of the user to return the shadow database
information for.
"""
if spwd is not None:
f = spwd.getspnam
elif shadow is not None:
f = shadow.getspnam
else:
return None
return runAsEffectiveUser(0, 0, f, username)
@implementer(ICredentialsChecker)
class UNIXPasswordDatabase:
"""
A checker which validates users out of the UNIX password databases, or
databases of a compatible format.
@ivar _getByNameFunctions: a C{list} of functions which are called in order
to valid a user. The default value is such that the /etc/passwd
database will be tried first, followed by the /etc/shadow database.
"""
credentialInterfaces = IUsernamePassword,
def __init__(self, getByNameFunctions=None):
if getByNameFunctions is None:
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
self._getByNameFunctions = getByNameFunctions
def requestAvatarId(self, credentials):
for func in self._getByNameFunctions:
try:
pwnam = func(credentials.username)
except KeyError:
return defer.fail(UnauthorizedLogin("invalid username"))
else:
if pwnam is not None:
crypted = pwnam[1]
if crypted == '':
continue
if verifyCryptedPassword(crypted, credentials.password):
return defer.succeed(credentials.username)
# fallback
return defer.fail(UnauthorizedLogin("unable to verify password"))
@implementer(ICredentialsChecker)
class SSHPublicKeyDatabase:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
"""
credentialInterfaces = (ISSHPrivateKey,)
_userdb = pwd
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self.checkKey, credentials)
d.addCallback(self._cbRequestAvatarId, credentials)
d.addErrback(self._ebRequestAvatarId)
return d
def _cbRequestAvatarId(self, validKey, credentials):
"""
Check whether the credentials themselves are valid, now that we know
if the key matches the user.
@param validKey: A boolean indicating whether or not the public key
matches a key in the user's authorized_keys file.
@param credentials: The credentials offered by the user.
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: (as a failure) if the key does not match the
user in C{credentials}. Also raised if the user provides an invalid
signature.
@raise ValidPublicKey: (as a failure) if the key matches the user but
the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@return: The user's username, if authentication was successful.
"""
if not validKey:
return failure.Failure(UnauthorizedLogin("invalid key"))
if not credentials.signature:
return failure.Failure(error.ValidPublicKey())
else:
try:
pubKey = keys.Key.fromString(credentials.blob)
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except: # any error should be treated as a failed login
log.err()
return failure.Failure(UnauthorizedLogin('error while verifying key'))
return failure.Failure(UnauthorizedLogin("unable to verify key"))
def getAuthorizedKeysFiles(self, credentials):
"""
Return a list of L{FilePath} instances for I{authorized_keys} files
which might contain information about authorized keys for the given
credentials.
On OpenSSH servers, the default location of the file containing the
list of authorized public keys is
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
U{deprecated by OpenSSH since
2001<http://marc.info/?m=100508718416162>}.
@return: A list of L{FilePath} instances to files with the authorized keys.
"""
pwent = self._userdb.getpwnam(credentials.username)
root = FilePath(pwent.pw_dir).child('.ssh')
files = ['authorized_keys', 'authorized_keys2']
return [root.child(f) for f in files]
def checkKey(self, credentials):
"""
Retrieve files containing authorized keys and check against user
credentials.
"""
ouid, ogid = self._userdb.getpwnam(credentials.username)[2:4]
for filepath in self.getAuthorizedKeysFiles(credentials):
if not filepath.exists():
continue
try:
lines = filepath.open()
except IOError, e:
if e.errno == errno.EACCES:
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
else:
raise
for l in lines:
l2 = l.split()
if len(l2) < 2:
continue
try:
if base64.decodestring(l2[1]) == credentials.blob:
return True
except binascii.Error:
continue
return False
def _ebRequestAvatarId(self, f):
if not f.check(UnauthorizedLogin):
log.msg(f)
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
return f
@implementer(ICredentialsChecker)
class SSHProtocolChecker:
"""
SSHProtocolChecker is a checker that requires multiple authentications
to succeed. To add a checker, call my registerChecker method with
the checker and the interface.
After each successful authenticate, I call my areDone method with the
avatar id. To get a list of the successful credentials for an avatar id,
use C{SSHProcotolChecker.successfulCredentials[avatarId]}. If L{areDone}
returns True, the authentication has succeeded.
"""
def __init__(self):
self.checkers = {}
self.successfulCredentials = {}
def get_credentialInterfaces(self):
return self.checkers.keys()
credentialInterfaces = property(get_credentialInterfaces)
def registerChecker(self, checker, *credentialInterfaces):
if not credentialInterfaces:
credentialInterfaces = checker.credentialInterfaces
for credentialInterface in credentialInterfaces:
self.checkers[credentialInterface] = checker
def requestAvatarId(self, credentials):
"""
Part of the L{ICredentialsChecker} interface. Called by a portal with
some credentials to check if they'll authenticate a user. We check the
interfaces that the credentials provide against our list of acceptable
checkers. If one of them matches, we ask that checker to verify the
credentials. If they're valid, we call our L{_cbGoodAuthentication}
method to continue.
@param credentials: the credentials the L{Portal} wants us to verify
"""
ifac = providedBy(credentials)
for i in ifac:
c = self.checkers.get(i)
if c is not None:
d = defer.maybeDeferred(c.requestAvatarId, credentials)
return d.addCallback(self._cbGoodAuthentication,
credentials)
return defer.fail(UnhandledCredentials("No checker for %s" % \
', '.join(map(reflect.qual, ifac))))
def _cbGoodAuthentication(self, avatarId, credentials):
"""
Called if a checker has verified the credentials. We call our
L{areDone} method to see if the whole of the successful authentications
are enough. If they are, we return the avatar ID returned by the first
checker.
"""
if avatarId not in self.successfulCredentials:
self.successfulCredentials[avatarId] = []
self.successfulCredentials[avatarId].append(credentials)
if self.areDone(avatarId):
del self.successfulCredentials[avatarId]
return avatarId
else:
raise error.NotEnoughAuthentication()
def areDone(self, avatarId):
"""
Override to determine if the authentication is finished for a given
avatarId.
@param avatarId: the avatar returned by the first checker. For
this checker to function correctly, all the checkers must
return the same avatar ID.
"""
return True
deprecatedModuleAttribute(
Version("Twisted", 15, 0, 0),
("Please use twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."),
__name__, "SSHPublicKeyDatabase")
class IAuthorizedKeysDB(Interface):
"""
An object that provides valid authorized ssh keys mapped to usernames.
@since: 15.0
"""
def getAuthorizedKeys(avatarId):
"""
Gets an iterable of authorized keys that are valid for the given
C{avatarId}.
@param avatarId: the ID of the avatar
@type avatarId: valid return value of
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
"""
def readAuthorizedKeyFile(fileobj, parseKey=keys.Key.fromString):
"""
Reads keys from an authorized keys file. Any non-comment line that cannot
be parsed as a key will be ignored, although that particular line will
be logged.
@param fileobj: something from which to read lines which can be parsed
as keys
@type fileobj: L{file}-like object
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@type parseKey: L{callable}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
@rtype: iterable
@since: 15.0
"""
for line in fileobj:
line = line.strip()
if line and not line.startswith('#'): # for comments
try:
yield parseKey(line)
except keys.BadKeyError as e:
log.msg('Unable to parse line "{0}" as a key: {1!s}'
.format(line, e))
def _keysFromFilepaths(filepaths, parseKey):
"""
Helper function that turns an iterable of filepaths into a generator of
keys. If any file cannot be read, a message is logged but it is
otherwise ignored.
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
@type filepaths: iterable
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}
@type parseKey: L{callable}
@return: generator of L{twisted.conch.ssh.keys.Key}
@rtype: generator
@since: 15.0
"""
for fp in filepaths:
if fp.exists():
try:
with fp.open() as f:
for key in readAuthorizedKeyFile(f, parseKey):
yield key
except (IOError, OSError) as e:
log.msg("Unable to read {0}: {1!s}".format(fp.path, e))
@implementer(IAuthorizedKeysDB)
class InMemorySSHKeyDB(object):
"""
Object that provides SSH public keys based on a dictionary of usernames
mapped to L{twisted.conch.ssh.keys.Key}s.
@since: 15.0
"""
def __init__(self, mapping):
"""
Initializes a new L{InMemorySSHKeyDB}.
@param mapping: mapping of usernames to iterables of
L{twisted.conch.ssh.keys.Key}s
@type mapping: C{dict}
"""
self._mapping = mapping
def getAuthorizedKeys(self, username):
return self._mapping.get(username, [])
@implementer(IAuthorizedKeysDB)
class UNIXAuthorizedKeysFiles(object):
"""
Object that provides SSH public keys based on public keys listed in
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
If any of the files cannot be read, a message is logged but that file is
otherwise ignored.
@since: 15.0
"""
def __init__(self, userdb=None, parseKey=keys.Key.fromString):
"""
Initializes a new L{UNIXAuthorizedKeysFiles}.
@param userdb: access to the Unix user account and password database
(default is the Python module L{pwd})
@type userdb: L{pwd}-like object
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@type parseKey: L{callable}
"""
self._userdb = userdb
self._parseKey = parseKey
if userdb is None:
self._userdb = pwd
def getAuthorizedKeys(self, username):
try:
passwd = self._userdb.getpwnam(username)
except KeyError:
return ()
root = FilePath(passwd.pw_dir).child('.ssh')
files = ['authorized_keys', 'authorized_keys2']
return _keysFromFilepaths((root.child(f) for f in files),
self._parseKey)
@implementer(ICredentialsChecker)
class SSHPublicKeyChecker(object):
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
Initializing this checker with a L{UNIXAuthorizedKeysFiles} should be
used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
@since: 15.0
"""
credentialInterfaces = (ISSHPrivateKey,)
def __init__(self, keydb):
"""
Initializes a L{SSHPublicKeyChecker}.
@param keydb: a provider of L{IAuthorizedKeysDB}
@type keydb: L{IAuthorizedKeysDB} provider
"""
self._keydb = keydb
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self._sanityCheckKey, credentials)
d.addCallback(self._checkKey, credentials)
d.addCallback(self._verifyKey, credentials)
return d
def _sanityCheckKey(self, credentials):
"""
Checks whether the provided credentials are a valid SSH key with a
signature (does not actually verify the signature).
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise ValidPublicKey: the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@raise BadKeyError: The key included with the credentials is not
recognized as a key.
@return: the key in the credentials
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if not credentials.signature:
raise error.ValidPublicKey()
return keys.Key.fromString(credentials.blob)
def _checkKey(self, pubKey, credentials):
"""
Checks the public key against all authorized keys (if any) for the
user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey:
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key is not authorized, or if there
was any error obtaining a list of authorized keys for the user.
@return: C{pubKey} if the key is authorized
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if any(key == pubKey for key in
self._keydb.getAuthorizedKeys(credentials.username)):
return pubKey
raise UnauthorizedLogin("Key not authorized")
def _verifyKey(self, pubKey, credentials):
"""
Checks whether the credentials themselves are valid, now that we know
if the key matches the user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey: L{twisted.conch.ssh.keys.Key}
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key signature is invalid or there
was any error verifying the signature.
@return: The user's username, if authentication was successful
@rtype: C{str}
"""
try:
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except: # any error should be treated as a failed login
log.err()
raise UnauthorizedLogin('Error while verifying key')
raise UnauthorizedLogin("Key signature invalid.")

View File

@ -0,0 +1,9 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Client support code for Conch.
Maintainer: Paul Swartz
"""

View File

@ -0,0 +1,73 @@
# -*- test-case-name: twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Accesses the key agent for user authentication.
Maintainer: Paul Swartz
"""
import os
from twisted.conch.ssh import agent, channel, keys
from twisted.internet import protocol, reactor
from twisted.python import log
class SSHAgentClient(agent.SSHAgentClient):
def __init__(self):
agent.SSHAgentClient.__init__(self)
self.blobs = []
def getPublicKeys(self):
return self.requestIdentities().addCallback(self._cbPublicKeys)
def _cbPublicKeys(self, blobcomm):
log.msg('got %i public keys' % len(blobcomm))
self.blobs = [x[0] for x in blobcomm]
def getPublicKey(self):
"""
Return a L{Key} from the first blob in C{self.blobs}, if any, or
return C{None}.
"""
if self.blobs:
return keys.Key.fromString(self.blobs.pop(0))
return None
class SSHAgentForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
d.addCallback(self._cbGotLocal)
d.addErrback(lambda x:self.loseConnection())
self.buf = ''
def _cbGotLocal(self, local):
self.local = local
self.dataReceived = self.local.transport.write
self.local.dataReceived = self.write
def dataReceived(self, data):
self.buf += data
def closed(self):
if self.local:
self.local.loseConnection()
self.local = None
class SSHAgentForwardingLocal(protocol.Protocol):
pass

View File

@ -0,0 +1,21 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
import direct
connectTypes = {"direct" : direct.connect}
def connect(host, port, options, verifyHostKey, userAuthObject):
useConnects = ['direct']
return _ebConnect(None, useConnects, host, port, options, verifyHostKey,
userAuthObject)
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
if not useConnects:
return f
connectType = useConnects.pop(0)
f = connectTypes[connectType]
d = f(host, port, options, vhk, uao)
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
return d

View File

@ -0,0 +1,260 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Various classes and functions for implementing user-interaction in the
command-line conch client.
You probably shouldn't use anything in this module directly, since it assumes
you are sitting at an interactive terminal. For example, to programmatically
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
"""
from twisted.python import log
from twisted.python.filepath import FilePath
from twisted.conch.error import ConchError
from twisted.conch.ssh import common, keys, userauth
from twisted.internet import defer, protocol, reactor
from twisted.conch.client.knownhosts import KnownHostsFile, ConsoleUI
from twisted.conch.client import agent
import os, sys, base64, getpass
# The default location of the known hosts file (probably should be parsed out
# of an ssh config file someday).
_KNOWN_HOSTS = "~/.ssh/known_hosts"
# This name is bound so that the unit tests can use 'patch' to override it.
_open = open
def verifyHostKey(transport, host, pubKey, fingerprint):
"""
Verify a host's key.
This function is a gross vestige of some bad factoring in the client
internals. The actual implementation, and a better signature of this logic
is in L{KnownHostsFile.verifyHostKey}. This function is not deprecated yet
because the callers have not yet been rehabilitated, but they should
eventually be changed to call that method instead.
However, this function does perform two functions not implemented by
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
known_hosts file based on the options (which should really be the options
object's job), and it provides an opener to L{ConsoleUI} which opens
'/dev/tty' so that the user will be prompted on the tty of the process even
if the input and output of the process has been redirected. This latter
part is, somewhat obviously, not portable, but I don't know of a portable
equivalent that could be used.
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
always the dotted-quad IP address of the host being connected to.
@type host: L{str}
@param transport: the client transport which is attempting to connect to
the given host.
@type transport: L{SSHClientTransport}
@param fingerprint: the fingerprint of the given public key, in
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
from the key itself.
@type fingerprint: L{str}
@param pubKey: The public key of the server being connected to.
@type pubKey: L{str}
@return: a L{Deferred} which fires with C{1} if the key was successfully
verified, or fails if the key could not be successfully verified. Failure
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
L{KeyboardInterrupt}.
"""
actualHost = transport.factory.options['host']
actualKey = keys.Key.fromString(pubKey)
kh = KnownHostsFile.fromPath(FilePath(
transport.factory.options['known-hosts']
or os.path.expanduser(_KNOWN_HOSTS)
))
ui = ConsoleUI(lambda : _open("/dev/tty", "r+b"))
return kh.verifyHostKey(ui, actualHost, host, actualKey)
def isInKnownHosts(host, pubKey, options):
"""checks to see if host is in the known_hosts file for the user.
returns 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
"""
keyType = common.getNS(pubKey)[0]
retVal = 0
if not options['known-hosts'] and not os.path.exists(os.path.expanduser('~/.ssh/')):
print 'Creating ~/.ssh directory...'
os.mkdir(os.path.expanduser('~/.ssh'))
kh_file = options['known-hosts'] or _KNOWN_HOSTS
try:
known_hosts = open(os.path.expanduser(kh_file))
except IOError:
return 0
for line in known_hosts.xreadlines():
split = line.split()
if len(split) < 3:
continue
hosts, hostKeyType, encodedKey = split[:3]
if host not in hosts.split(','): # incorrect host
continue
if hostKeyType != keyType: # incorrect type of key
continue
try:
decodedKey = base64.decodestring(encodedKey)
except:
continue
if decodedKey == pubKey:
return 1
else:
retVal = 2
return retVal
class SSHUserAuthClient(userauth.SSHUserAuthClient):
def __init__(self, user, options, *args):
userauth.SSHUserAuthClient.__init__(self, user, *args)
self.keyAgent = None
self.options = options
self.usedFiles = []
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
def serviceStarted(self):
if 'SSH_AUTH_SOCK' in os.environ and not self.options['noagent']:
log.msg('using agent')
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
d.addCallback(self._setAgent)
d.addErrback(self._ebSetAgent)
else:
userauth.SSHUserAuthClient.serviceStarted(self)
def serviceStopped(self):
if self.keyAgent:
self.keyAgent.transport.loseConnection()
self.keyAgent = None
def _setAgent(self, a):
self.keyAgent = a
d = self.keyAgent.getPublicKeys()
d.addBoth(self._ebSetAgent)
return d
def _ebSetAgent(self, f):
userauth.SSHUserAuthClient.serviceStarted(self)
def _getPassword(self, prompt):
try:
oldout, oldin = sys.stdout, sys.stdin
sys.stdin = sys.stdout = open('/dev/tty','r+')
p=getpass.getpass(prompt)
sys.stdout,sys.stdin=oldout,oldin
return p
except (KeyboardInterrupt, IOError):
print
raise ConchError('PEBKAC')
def getPassword(self, prompt = None):
if not prompt:
prompt = "%s@%s's password: " % (self.user, self.transport.transport.getPeer().host)
try:
p = self._getPassword(prompt)
return defer.succeed(p)
except ConchError:
return defer.fail()
def getPublicKey(self):
"""
Get a public key from the key agent if possible, otherwise look in
the next configured identity file for one.
"""
if self.keyAgent:
key = self.keyAgent.getPublicKey()
if key is not None:
return key
files = [x for x in self.options.identitys if x not in self.usedFiles]
log.msg(str(self.options.identitys))
log.msg(str(files))
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += '.pub'
if not os.path.exists(file):
return self.getPublicKey() # try again
try:
return keys.Key.fromFile(file)
except keys.BadKeyError:
return self.getPublicKey() # try again
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: C{str}
"""
if not self.usedFiles: # agent key
return self.keyAgent.signData(publicKey.blob(), signData)
else:
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Try to load the private key from the last used file identified by
C{getPublicKey}, potentially asking for the passphrase if the key is
encrypted.
"""
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file))
except keys.EncryptedKeyError:
for i in range(3):
prompt = "Enter passphrase for key '%s': " % \
self.usedFiles[-1]
try:
p = self._getPassword(prompt)
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
except (keys.BadKeyError, ConchError):
pass
return defer.fail(ConchError('bad password'))
raise
except KeyboardInterrupt:
print
reactor.stop()
def getGenericAnswers(self, name, instruction, prompts):
responses = []
try:
oldout, oldin = sys.stdout, sys.stdin
sys.stdin = sys.stdout = open('/dev/tty','r+')
if name:
print name
if instruction:
print instruction
for prompt, echo in prompts:
if echo:
responses.append(raw_input(prompt))
else:
responses.append(getpass.getpass(prompt))
finally:
sys.stdout,sys.stdin=oldout,oldin
return defer.succeed(responses)

View File

@ -0,0 +1,107 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.internet import defer, protocol, reactor
from twisted.conch import error
from twisted.conch.ssh import transport
from twisted.python import log
class SSHClientFactory(protocol.ClientFactory):
def __init__(self, d, options, verifyHostKey, userAuthObject):
self.d = d
self.options = options
self.verifyHostKey = verifyHostKey
self.userAuthObject = userAuthObject
def clientConnectionLost(self, connector, reason):
if self.options['reconnect']:
connector.connect()
def clientConnectionFailed(self, connector, reason):
if self.d is None:
return
d, self.d = self.d, None
d.errback(reason)
def buildProtocol(self, addr):
trans = SSHClientTransport(self)
if self.options['ciphers']:
trans.supportedCiphers = self.options['ciphers']
if self.options['macs']:
trans.supportedMACs = self.options['macs']
if self.options['compress']:
trans.supportedCompressions[0:1] = ['zlib']
if self.options['host-key-algorithms']:
trans.supportedPublicKeys = self.options['host-key-algorithms']
return trans
class SSHClientTransport(transport.SSHClientTransport):
def __init__(self, factory):
self.factory = factory
self.unixServer = None
def connectionLost(self, reason):
if self.unixServer:
d = self.unixServer.stopListening()
self.unixServer = None
else:
d = defer.succeed(None)
d.addCallback(lambda x:
transport.SSHClientTransport.connectionLost(self, reason))
def receiveError(self, code, desc):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
d.errback(error.ConchError(desc, code))
def sendDisconnect(self, code, reason):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
transport.SSHClientTransport.sendDisconnect(self, code, reason)
d.errback(error.ConchError(reason, code))
def receiveDebug(self, alwaysDisplay, message, lang):
log.msg('Received Debug Message: %s' % message)
if alwaysDisplay: # XXX what should happen here?
print message
def verifyHostKey(self, pubKey, fingerprint):
return self.factory.verifyHostKey(self, self.transport.getPeer().host, pubKey,
fingerprint)
def setService(self, service):
log.msg('setting client server to %s' % service)
transport.SSHClientTransport.setService(self, service)
if service.name != 'ssh-userauth' and self.factory.d is not None:
d, self.factory.d = self.factory.d, None
d.callback(None)
def connectionSecure(self):
self.requestService(self.factory.userAuthObject)
def connect(host, port, options, verifyHostKey, userAuthObject):
d = defer.Deferred()
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
reactor.connectTCP(host, port, factory)
return d

View File

@ -0,0 +1,621 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of the OpenSSH known_hosts database.
@since: 8.2
"""
import hmac
from binascii import Error as DecodeError, b2a_base64
from hashlib import sha1
from zope.interface import implementer
from twisted.python.randbytes import secureRandom
from twisted.internet import defer
from twisted.python import log
from twisted.python.util import FancyEqMixin
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.error import HostKeyChanged, UserRejectedKey, InvalidEntry
from twisted.conch.ssh.keys import Key, BadKeyError
def _b64encode(s):
"""
Encode a binary string as base64 with no trailing newline.
@param s: The string to encode.
@type s: L{bytes}
@return: The base64-encoded string.
@rtype: L{bytes}
"""
return b2a_base64(s).strip()
def _extractCommon(string):
"""
Extract common elements of base64 keys from an entry in a hosts file.
@param string: A known hosts file entry (a single line).
@type string: L{bytes}
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
simply the beginning of the line up to the first occurrence of
whitespace.
@rtype: L{tuple}
"""
elements = string.split(None, 2)
if len(elements) != 3:
raise InvalidEntry()
hostnames, keyType, keyAndComment = elements
splitkey = keyAndComment.split(None, 1)
if len(splitkey) == 2:
keyString, comment = splitkey
comment = comment.rstrip("\n")
else:
keyString = splitkey[0]
comment = None
key = Key.fromString(keyString.decode('base64'))
return hostnames, keyType, key, comment
class _BaseEntry(object):
"""
Abstract base of both hashed and non-hashed entry objects, since they
represent keys and key types the same way.
@ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
@type keyType: L{str}
@ivar publicKey: The server public key indicated by this line.
@type publicKey: L{twisted.conch.ssh.keys.Key}
@ivar comment: Trailing garbage after the key line.
@type comment: L{str}
"""
def __init__(self, keyType, publicKey, comment):
self.keyType = keyType
self.publicKey = publicKey
self.comment = comment
def matchesKey(self, keyObject):
"""
Check to see if this entry matches a given key object.
@param keyObject: A public key object to check.
@type keyObject: L{Key}
@return: C{True} if this entry's key matches C{keyObject}, C{False}
otherwise.
@rtype: L{bool}
"""
return self.publicKey == keyObject
@implementer(IKnownHostEntry)
class PlainEntry(_BaseEntry):
"""
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
file.
@ivar _hostnames: the list of all host-names associated with this entry.
@type _hostnames: L{list} of L{str}
"""
def __init__(self, hostnames, keyType, publicKey, comment):
self._hostnames = hostnames
super(PlainEntry, self).__init__(keyType, publicKey, comment)
def fromString(cls, string):
"""
Parse a plain-text entry in a known_hosts file, and return a
corresponding L{PlainEntry}.
@param string: a space-separated string formatted like "hostname
key-type base64-key-data comment".
@type string: L{str}
@raise DecodeError: if the key is not valid encoded as valid base64.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: an IKnownHostEntry representing the hostname and key in the
input line.
@rtype: L{PlainEntry}
"""
hostnames, keyType, key, comment = _extractCommon(string)
self = cls(hostnames.split(","), keyType, key, comment)
return self
fromString = classmethod(fromString)
def matchesHost(self, hostname):
"""
Check to see if this entry matches a given hostname.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{str}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
return hostname in self._hostnames
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by recording the comma-separated
hostnames, key type, and base-64 encoded key.
@return: The string representation of this entry, with unhashed hostname
information.
@rtype: L{bytes}
"""
fields = [','.join(self._hostnames),
self.keyType,
_b64encode(self.publicKey.blob())]
if self.comment is not None:
fields.append(self.comment)
return ' '.join(fields)
@implementer(IKnownHostEntry)
class UnparsedEntry(object):
"""
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
parsed; therefore it matches no keys and no hosts.
"""
def __init__(self, string):
"""
Create an unparsed entry from a line in a known_hosts file which cannot
otherwise be parsed.
"""
self._string = string
def matchesHost(self, hostname):
"""
Always returns False.
"""
return False
def matchesKey(self, key):
"""
Always returns False.
"""
return False
def toString(self):
"""
Returns the input line, without its newline if one was given.
@return: The string representation of this entry, almost exactly as was
used to initialize this entry but without a trailing newline.
@rtype: L{bytes}
"""
return self._string.rstrip("\n")
def _hmacedString(key, string):
"""
Return the SHA-1 HMAC hash of the given key and string.
@param key: The HMAC key.
@type key: L{bytes}
@param string: The string to be hashed.
@type string: L{bytes}
@return: The keyed hash value.
@rtype: L{bytes}
"""
hash = hmac.HMAC(key, digestmod=sha1)
hash.update(string)
return hash.digest()
@implementer(IKnownHostEntry)
class HashedEntry(_BaseEntry, FancyEqMixin):
"""
A L{HashedEntry} is a representation of an entry in a known_hosts file
where the hostname has been hashed and salted.
@ivar _hostSalt: the salt to combine with a hostname for hashing.
@ivar _hostHash: the hashed representation of the hostname.
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
known_hosts file as opposed to a plaintext one.
"""
MAGIC = '|1|'
compareAttributes = (
"_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
def __init__(self, hostSalt, hostHash, keyType, publicKey, comment):
self._hostSalt = hostSalt
self._hostHash = hostHash
super(HashedEntry, self).__init__(keyType, publicKey, comment)
def fromString(cls, string):
"""
Load a hashed entry from a string representing a line in a known_hosts
file.
@param string: A complete single line from a I{known_hosts} file,
formatted as defined by OpenSSH.
@type string: L{bytes}
@raise DecodeError: if the key, the hostname, or the is not valid
encoded as valid base64
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid, or the host/hash portion contains
more items than just the host and hash.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: The newly created L{HashedEntry} instance, initialized with the
information from C{string}.
"""
stuff, keyType, key, comment = _extractCommon(string)
saltAndHash = stuff[len(cls.MAGIC):].split("|")
if len(saltAndHash) != 2:
raise InvalidEntry()
hostSalt, hostHash = saltAndHash
self = cls(hostSalt.decode("base64"), hostHash.decode("base64"),
keyType, key, comment)
return self
fromString = classmethod(fromString)
def matchesHost(self, hostname):
"""
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
input to the stored hash.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
return (_hmacedString(self._hostSalt, hostname) == self._hostHash)
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
hash, and key.
@return: The string representation of this entry, with the hostname part
hashed.
@rtype: L{bytes}
"""
fields = [self.MAGIC + '|'.join([_b64encode(self._hostSalt),
_b64encode(self._hostHash)]),
self.keyType,
_b64encode(self.publicKey.blob())]
if self.comment is not None:
fields.append(self.comment)
return ' '.join(fields)
class KnownHostsFile(object):
"""
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
to this instance in memory but not yet saved.
@ivar _clobber: A flag indicating whether the current contents of the save
path will be disregarded and potentially overwritten or not. If
C{True}, this will be done. If C{False}, entries in the save path will
be read and new entries will be saved by appending rather than
overwriting.
@type _clobber: L{bool}
@ivar _savePath: See C{savePath} parameter of L{__init__}.
"""
def __init__(self, savePath):
"""
Create a new, empty KnownHostsFile.
Unless you want to erase the current contents of C{savePath}, you want
to use L{KnownHostsFile.fromPath} instead.
@param savePath: The L{FilePath} to which to save new entries.
@type savePath: L{FilePath}
"""
self._added = []
self._savePath = savePath
self._clobber = True
@property
def savePath(self):
"""
@see: C{savePath} parameter of L{__init__}
"""
return self._savePath
def iterentries(self):
"""
Iterate over the host entries in this file.
@return: An iterable the elements of which provide L{IKnownHostEntry}.
There is an element for each entry in the file as well as an element
for each added but not yet saved entry.
@rtype: iterable of L{IKnownHostEntry} providers
"""
for entry in self._added:
yield entry
if self._clobber:
return
try:
fp = self._savePath.open()
except IOError:
return
try:
for line in fp:
try:
if line.startswith(HashedEntry.MAGIC):
entry = HashedEntry.fromString(line)
else:
entry = PlainEntry.fromString(line)
except (DecodeError, InvalidEntry, BadKeyError):
entry = UnparsedEntry(line)
yield entry
finally:
fp.close()
def hasHostKey(self, hostname, key):
"""
Check for an entry with matching hostname and key.
@param hostname: A hostname or IP address literal to check for.
@type hostname: L{bytes}
@param key: The public key to check for.
@type key: L{Key}
@return: C{True} if the given hostname and key are present in this file,
C{False} if they are not.
@rtype: L{bool}
@raise HostKeyChanged: if the host key found for the given hostname
does not match the given key.
"""
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
if entry.matchesHost(hostname):
if entry.matchesKey(key):
return True
else:
# Notice that lineidx is 0-based but HostKeyChanged.lineno
# is 1-based.
if lineidx < 0:
line = None
path = None
else:
line = lineidx + 1
path = self._savePath
raise HostKeyChanged(entry, path, line)
return False
def verifyHostKey(self, ui, hostname, ip, key):
"""
Verify the given host key for the given IP and host, asking for
confirmation from, and notifying, the given UI about changes to this
file.
@param ui: The user interface to request an IP address from.
@param hostname: The hostname that the user requested to connect to.
@param ip: The string representation of the IP address that is actually
being connected to.
@param key: The public key of the server.
@return: a L{Deferred} that fires with True when the key has been
verified, or fires with an errback when the key either cannot be
verified or has changed.
@rtype: L{Deferred}
"""
hhk = defer.maybeDeferred(self.hasHostKey, hostname, key)
def gotHasKey(result):
if result:
if not self.hasHostKey(ip, key):
ui.warn("Warning: Permanently added the %s host key for "
"IP address '%s' to the list of known hosts." %
(key.type(), ip))
self.addHostKey(ip, key)
self.save()
return result
else:
def promptResponse(response):
if response:
self.addHostKey(hostname, key)
self.addHostKey(ip, key)
self.save()
return response
else:
raise UserRejectedKey()
proceed = ui.prompt(
"The authenticity of host '%s (%s)' "
"can't be established.\n"
"RSA key fingerprint is %s.\n"
"Are you sure you want to continue connecting (yes/no)? " %
(hostname, ip, key.fingerprint()))
return proceed.addCallback(promptResponse)
return hhk.addCallback(gotHasKey)
def addHostKey(self, hostname, key):
"""
Add a new L{HashedEntry} to the key database.
Note that you still need to call L{KnownHostsFile.save} if you wish
these changes to be persisted.
@param hostname: A hostname or IP address literal to associate with the
new entry.
@type hostname: L{bytes}
@param key: The public key to associate with the new entry.
@type key: L{Key}
@return: The L{HashedEntry} that was added.
@rtype: L{HashedEntry}
"""
salt = secureRandom(20)
keyType = "ssh-" + key.type().lower()
entry = HashedEntry(salt, _hmacedString(salt, hostname),
keyType, key, None)
self._added.append(entry)
return entry
def save(self):
"""
Save this L{KnownHostsFile} to the path it was loaded from.
"""
p = self._savePath.parent()
if not p.isdir():
p.makedirs()
if self._clobber:
mode = "w"
else:
mode = "a"
with self._savePath.open(mode) as hostsFileObj:
if self._added:
hostsFileObj.write(
"\n".join([entry.toString() for entry in self._added]) +
"\n")
self._added = []
self._clobber = False
def fromPath(cls, path):
"""
Create a new L{KnownHostsFile}, potentially reading existing known
hosts information from the given file.
@param path: A path object to use for both reading contents from and
later saving to. If no file exists at this path, it is not an
error; a L{KnownHostsFile} with no entries is returned.
@type path: L{FilePath}
@return: A L{KnownHostsFile} initialized with entries from C{path}.
@rtype: L{KnownHostsFile}
"""
knownHosts = cls(path)
knownHosts._clobber = False
return knownHosts
fromPath = classmethod(fromPath)
class ConsoleUI(object):
"""
A UI object that can ask true/false questions and post notifications on the
console, to be used during key verification.
"""
def __init__(self, opener):
"""
@param opener: A no-argument callable which should open a console
binary-mode file-like object to be used for reading and writing.
This initializes the C{opener} attribute.
@type opener: callable taking no arguments and returning a read/write
file-like object
"""
self.opener = opener
def prompt(self, text):
"""
Write the given text as a prompt to the console output, then read a
result from the console input.
@param text: Something to present to a user to solicit a yes or no
response.
@type text: L{bytes}
@return: a L{Deferred} which fires with L{True} when the user answers
'yes' and L{False} when the user answers 'no'. It may errback if
there were any I/O errors.
"""
d = defer.succeed(None)
def body(ignored):
f = self.opener()
f.write(text)
while True:
answer = f.readline().strip().lower()
if answer == 'yes':
f.close()
return True
elif answer == 'no':
f.close()
return False
else:
f.write("Please type 'yes' or 'no': ")
return d.addCallback(body)
def warn(self, text):
"""
Notify the user (non-interactively) of the provided text, by writing it
to the console.
@param text: Some information the user is to be made aware of.
@type text: L{bytes}
"""
try:
f = self.opener()
f.write(text)
f.close()
except:
log.err()

View File

@ -0,0 +1,96 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.conch.ssh.transport import SSHClientTransport, SSHCiphers
from twisted.python import usage
import sys
class ConchOptions(usage.Options):
optParameters = [['user', 'l', None, 'Log in using this user name.'],
['identity', 'i', None],
['ciphers', 'c', None],
['macs', 'm', None],
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
['option', 'o', None, 'Ignored OpenSSH options'],
['host-key-algorithms', '', None],
['known-hosts', '', None, 'File to check for host keys'],
['user-authentications', '', None, 'Types of user authentications to use.'],
['logfile', '', None, 'File to log to, or - for stdout'],
]
optFlags = [['version', 'V', 'Display version number only.'],
['compress', 'C', 'Enable compression.'],
['log', 'v', 'Enable logging (defaults to stderr)'],
['nox11', 'x', 'Disable X11 connection forwarding (default)'],
['agent', 'A', 'Enable authentication agent forwarding'],
['noagent', 'a', 'Disable authentication agent forwarding (default)'],
['reconnect', 'r', 'Reconnect to the server if the connection is lost.'],
]
compData = usage.Completions(
mutuallyExclusive=[("agent", "noagent")],
optActions={
"user": usage.CompleteUsernames(),
"ciphers": usage.CompleteMultiList(
SSHCiphers.cipherMap.keys(),
descr='ciphers to choose from'),
"macs": usage.CompleteMultiList(
SSHCiphers.macMap.keys(),
descr='macs to choose from'),
"host-key-algorithms": usage.CompleteMultiList(
SSHClientTransport.supportedPublicKeys,
descr='host key algorithms to choose from'),
#"user-authentications": usage.CompleteMultiList(?
# descr='user authentication types' ),
},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr='argument',
repeat=True)]
)
def __init__(self, *args, **kw):
usage.Options.__init__(self, *args, **kw)
self.identitys = []
self.conns = None
def opt_identity(self, i):
"""Identity for public-key authentication"""
self.identitys.append(i)
def opt_ciphers(self, ciphers):
"Select encryption algorithms"
ciphers = ciphers.split(',')
for cipher in ciphers:
if not SSHCiphers.cipherMap.has_key(cipher):
sys.exit("Unknown cipher type '%s'" % cipher)
self['ciphers'] = ciphers
def opt_macs(self, macs):
"Specify MAC algorithms"
macs = macs.split(',')
for mac in macs:
if not SSHCiphers.macMap.has_key(mac):
sys.exit("Unknown mac type '%s'" % mac)
self['macs'] = macs
def opt_host_key_algorithms(self, hkas):
"Select host key algorithms"
hkas = hkas.split(',')
for hka in hkas:
if hka not in SSHClientTransport.supportedPublicKeys:
sys.exit("Unknown host key type '%s'" % hka)
self['host-key-algorithms'] = hkas
def opt_user_authentications(self, uas):
"Choose how to authenticate to the remote server"
self['user-authentications'] = uas.split(',')
# def opt_compress(self):
# "Enable compression"
# self.enableCompression = 1
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']

View File

@ -0,0 +1,832 @@
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Endpoint implementations of various SSH interactions.
"""
__all__ = [
'AuthenticationFailed', 'SSHCommandAddress', 'SSHCommandClientEndpoint']
from struct import unpack
from os.path import expanduser
from zope.interface import Interface, implementer
from twisted.python.filepath import FilePath
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone, ProcessTerminated
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.defer import Deferred, succeed, CancelledError
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
from twisted.conch.ssh.keys import Key
from twisted.conch.ssh.common import NS
from twisted.conch.ssh.transport import SSHClientTransport
from twisted.conch.ssh.connection import SSHConnection
from twisted.conch.ssh.userauth import SSHUserAuthClient
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import _KNOWN_HOSTS
class AuthenticationFailed(Exception):
"""
An SSH session could not be established because authentication was not
successful.
"""
# This should be public. See #6541.
class _ISSHConnectionCreator(Interface):
"""
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
"""
def secureConnection():
"""
Return a new, connected, secured, but not yet authenticated instance of
L{twisted.conch.ssh.transport.SSHServerTransport} or
L{twisted.conch.ssh.transport.SSHClientTransport}.
"""
def cleanupConnection(connection, immediate):
"""
Perform cleanup necessary for a connection object previously returned
from this creator's C{secureConnection} method.
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
previous call to C{secureConnection}. It is no longer needed by the
caller of that method and may be closed or otherwise cleaned up as
necessary.
@param immediate: If C{True} don't wait for any network communication,
just close the connection immediately and as aggressively as
necessary.
"""
class SSHCommandAddress(object):
"""
An L{SSHCommandAddress} instance represents the address of an SSH server, a
username which was used to authenticate with that server, and a command
which was run there.
@ivar server: See L{__init__}
@ivar username: See L{__init__}
@ivar command: See L{__init__}
"""
def __init__(self, server, username, command):
"""
@param server: The address of the SSH server on which the command is
running.
@type server: L{IAddress} provider
@param username: An authentication username which was used to
authenticate against the server at the given address.
@type username: L{bytes}
@param command: A command which was run in a session channel on the
server at the given address.
@type command: L{bytes}
"""
self.server = server
self.username = username
self.command = command
class _CommandChannel(SSHChannel):
"""
A L{_CommandChannel} executes a command in a session channel and connects
its input and output to an L{IProtocol} provider.
@ivar _creator: See L{__init__}
@ivar _command: See L{__init__}
@ivar _protocolFactory: See L{__init__}
@ivar _commandConnected: See L{__init__}
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
which is hooked up to the running command's input and output streams.
"""
name = b'session'
def __init__(self, creator, command, protocolFactory, commandConnected):
"""
@param creator: The L{_ISSHConnectionCreator} provider which was used
to get the connection which this channel exists on.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command to be executed.
@type command: L{bytes}
@param protocolFactory: A client factory to use to build a L{IProtocol}
provider to use to associate with the running command.
@param commandConnected: A L{Deferred} to use to signal that execution
of the command has failed or that it has succeeded and the command
is now running.
@type commandConnected: L{Deferred}
"""
SSHChannel.__init__(self)
self._creator = creator
self._command = command
self._protocolFactory = protocolFactory
self._commandConnected = commandConnected
self._reason = None
def openFailed(self, reason):
"""
When the request to open a new channel to run this command in fails,
fire the C{commandConnected} deferred with a failure indicating that.
"""
self._commandConnected.errback(reason)
def channelOpen(self, ignored):
"""
When the request to open a new channel to run this command in succeeds,
issue an C{"exec"} request to run the command.
"""
command = self.conn.sendRequest(
self, 'exec', NS(self._command), wantReply=True)
command.addCallbacks(self._execSuccess, self._execFailure)
def _execFailure(self, reason):
"""
When the request to execute the command in this channel fails, fire the
C{commandConnected} deferred with a failure indicating this.
@param reason: The cause of the command execution failure.
@type reason: L{Failure}
"""
self._commandConnected.errback(reason)
def _execSuccess(self, ignored):
"""
When the request to execute the command in this channel succeeds, use
C{protocolFactory} to build a protocol to handle the command's input and
output and connect the protocol to a transport representing those
streams.
Also fire C{commandConnected} with the created protocol after it is
connected to its transport.
@param ignored: The (ignored) result of the execute request
"""
self._protocol = self._protocolFactory.buildProtocol(
SSHCommandAddress(
self.conn.transport.transport.getPeer(),
self.conn.transport.creator.username,
self.conn.transport.creator.command))
self._protocol.makeConnection(self)
self._commandConnected.callback(self._protocol)
def dataReceived(self, data):
"""
When the command's stdout data arrives over the channel, deliver it to
the protocol instance.
@param data: The bytes from the command's stdout.
@type data: L{bytes}
"""
self._protocol.dataReceived(data)
def request_exit_status(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
status of the command.
@type data: L{bytes}
"""
(status,) = unpack('>L', data)
if status != 0:
self._reason = ProcessTerminated(status, None, None)
def request_exit_signal(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
signal of the command.
@type data: L{bytes}
"""
(signal,) = unpack('>L', data)
self._reason = ProcessTerminated(None, signal, None)
def closed(self):
"""
When the channel closes, deliver disconnection notification to the
protocol.
"""
self._creator.cleanupConnection(self.conn, False)
if self._reason is None:
reason = ConnectionDone("ssh channel closed")
else:
reason = self._reason
self._protocol.connectionLost(Failure(reason))
class _ConnectionReady(SSHConnection):
"""
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
propagates the I{serviceStarted} event to a L{Deferred} to be handled
elsewhere.
"""
def __init__(self, ready):
"""
@param ready: A L{Deferred} which should be fired when
I{serviceStarted} happens.
"""
SSHConnection.__init__(self)
self._ready = ready
def serviceStarted(self):
"""
When the SSH I{connection} I{service} this object represents is ready to
be used, fire the C{connectionReady} L{Deferred} to publish that event
to some other interested party.
"""
self._ready.callback(self)
del self._ready
class _UserAuth(SSHUserAuthClient):
"""
L{_UserAuth} implements the client part of SSH user authentication in the
convenient way a user might expect if they are familiar with the
interactive I{ssh} command line client.
L{_UserAuth} supports key-based authentication, password-based
authentication, and delegating authentication to an agent.
"""
password = None
keys = None
agent = None
def getPublicKey(self):
"""
Retrieve the next public key object to offer to the server, possibly
delegating to an authentication agent if there is one.
@return: The public part of a key pair that could be used to
authenticate with the server, or C{None} if there are no more public
keys to try.
@rtype: L{twisted.conch.ssh.keys.Key} or L{types.NoneType}
"""
if self.agent is not None:
return self.agent.getPublicKey()
if self.keys:
self.key = self.keys.pop(0)
else:
self.key = None
return self.key.public()
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: C{str}
"""
if self.agent is not None:
return self.agent.signData(publicKey.blob(), signData)
else:
return SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Get the private part of a key pair to use for authentication. The key
corresponds to the public part most recently returned from
C{getPublicKey}.
@return: A L{Deferred} which fires with the private key.
@rtype: L{Deferred}
"""
return succeed(self.key)
def getPassword(self):
"""
Get the password to use for authentication.
@return: A L{Deferred} which fires with the password, or C{None} if the
password was not specified.
"""
if self.password is None:
return
return succeed(self.password)
def ssh_USERAUTH_SUCCESS(self, packet):
"""
Handle user authentication success in the normal way, but also make a
note of the state change on the L{_CommandTransport}.
"""
self.transport._state = b'CHANNELLING'
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
class _CommandTransport(SSHClientTransport):
"""
L{_CommandTransport} is an SSH client I{transport} which includes a host key
verification step before it will proceed to secure the connection.
L{_CommandTransport} also knows how to set up a connection to an
authentication agent if it is told where it can connect to one.
"""
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
_state = b'STARTING'
_hostKeyFailure = None
def __init__(self, creator):
"""
@param creator: The L{_NewConnectionHelper} that created this
connection.
@type creator: L{_NewConnectionHelper}.
"""
self.connectionReady = Deferred(
lambda d: self.transport.abortConnection())
# Clear the reference to that deferred to help the garbage collector
# and to signal to other parts of this implementation (in particular
# connectionLost) that it has already been fired and does not need to
# be fired again.
def readyFired(result):
self.connectionReady = None
return result
self.connectionReady.addBoth(readyFired)
self.creator = creator
def verifyHostKey(self, hostKey, fingerprint):
"""
Ask the L{KnownHostsFile} provider available on the factory which
created this protocol this protocol to verify the given host key.
@return: A L{Deferred} which fires with the result of
L{KnownHostsFile.verifyHostKey}.
"""
hostname = self.creator.hostname
ip = self.transport.getPeer().host
self._state = b'SECURING'
d = self.creator.knownHosts.verifyHostKey(
self.creator.ui, hostname, ip, Key.fromString(hostKey))
d.addErrback(self._saveHostKeyFailure)
return d
def _saveHostKeyFailure(self, reason):
"""
When host key verification fails, record the reason for the failure in
order to fire a L{Deferred} with it later.
@param reason: The cause of the host key verification failure.
@type reason: L{Failure}
@return: C{reason}
@rtype: L{Failure}
"""
self._hostKeyFailure = reason
return reason
def connectionSecure(self):
"""
When the connection is secure, start the authentication process.
"""
self._state = b'AUTHENTICATING'
command = _ConnectionReady(self.connectionReady)
userauth = _UserAuth(self.creator.username, command)
userauth.password = self.creator.password
if self.creator.keys:
userauth.keys = list(self.creator.keys)
if self.creator.agentEndpoint is not None:
d = self._connectToAgent(userauth, self.creator.agentEndpoint)
else:
d = succeed(None)
def maybeGotAgent(ignored):
self.requestService(userauth)
d.addBoth(maybeGotAgent)
def _connectToAgent(self, userauth, endpoint):
"""
Set up a connection to the authentication agent and trigger its
initialization.
@param userauth: The L{_UserAuth} instance which is in charge of the
overall authentication process.
@type userauth: L{_UserAuth}
@param endpoint: An endpoint which can be used to connect to the
authentication agent.
@type endpoint: L{IStreamClientEndpoint} provider
@return: A L{Deferred} which fires when the agent connection is ready
for use.
"""
factory = Factory()
factory.protocol = SSHAgentClient
d = endpoint.connect(factory)
def connected(agent):
userauth.agent = agent
return agent.getPublicKeys()
d.addCallback(connected)
return d
def connectionLost(self, reason):
"""
When the underlying connection to the SSH server is lost, if there were
any connection setup errors, propagate them.
"""
if self._state == b'RUNNING' or self.connectionReady is None:
return
if self._state == b'SECURING' and self._hostKeyFailure is not None:
reason = self._hostKeyFailure
elif self._state == b'AUTHENTICATING':
reason = Failure(
AuthenticationFailed("Connection lost while authenticating"))
self.connectionReady.errback(reason)
@implementer(IStreamClientEndpoint)
class SSHCommandClientEndpoint(object):
"""
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
SSH servers.
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
it in any one of a number of different ways (keys, passwords, agents),
launch a command over that connection and then associate its input and
output with a protocol.
It can also re-use an existing, already-authenticated SSH connection
(perhaps one which already has some SSH channels being used for other
purposes). In this case it creates a new SSH channel to use to execute the
command. Notably this means it supports multiplexing several different
command invocations over a single SSH connection.
"""
def __init__(self, creator, command):
"""
@param creator: An L{_ISSHConnectionCreator} provider which will be
used to set up the SSH connection which will be used to run a
command.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command line to execute on the SSH server. This
byte string is interpreted by a shell on the SSH server, so it may
have a value like C{"ls /"}. Take care when trying to run a command
like C{"/Volumes/My Stuff/a-program"} - spaces (and other special
bytes) may require escaping.
@type command: L{bytes}
"""
self._creator = creator
self._command = command
@classmethod
def newConnection(cls, reactor, command, username, hostname, port=None,
keys=None, password=None, agentEndpoint=None,
knownHosts=None, ui=None):
"""
Create and return a new endpoint which will try to create a new
connection to an SSH server and run a command over it. It will also
close the connection if there are problems leading up to the command
being executed, after the command finishes, or if the connection
L{Deferred} is cancelled.
@param reactor: The reactor to use to establish the connection.
@type reactor: L{IReactorTCP} provider
@param command: See L{__init__}'s C{command} argument.
@param username: The username with which to authenticate to the SSH
server.
@type username: L{bytes}
@param hostname: The hostname of the SSH server.
@type hostname: L{bytes}
@param port: The port number of the SSH server. By default, the
standard SSH port number is used.
@type port: L{int}
@param keys: Private keys with which to authenticate to the SSH server,
if key authentication is to be attempted (otherwise C{None}).
@type keys: L{list} of L{Key}
@param password: The password with which to authenticate to the SSH
server, if password authentication is to be attempted (otherwise
C{None}).
@type password: L{bytes} or L{types.NoneType}
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
used to connect to an SSH agent, if one is to be used to help with
authentication.
@type agentEndpoint: L{IStreamClientEndpoint} provider
@param knownHosts: The currently known host keys, used to check the
host key presented by the server we actually connect to.
@type knownHosts: L{KnownHostsFile}
@param ui: An object for interacting with users to make decisions about
whether to accept the server host keys. If C{None}, a L{ConsoleUI}
connected to /dev/tty will be used; if /dev/tty is unavailable, an
object which answers C{b"no"} to all prompts will be used.
@type ui: L{NoneType} or L{ConsoleUI}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _NewConnectionHelper(
reactor, hostname, port, command, username, keys, password,
agentEndpoint, knownHosts, ui)
return cls(helper, command)
@classmethod
def existingConnection(cls, connection, command):
"""
Create and return a new endpoint which will try to open a new channel on
an existing SSH connection and run a command over it. It will B{not}
close the connection if there is a problem executing the command or
after the command finishes.
@param connection: An existing connection to an SSH server.
@type connection: L{SSHConnection}
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
C{command} parameter.
@type command: L{bytes}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _ExistingConnectionHelper(connection)
return cls(helper, command)
def connect(self, protocolFactory):
"""
Set up an SSH connection, use a channel from that connection to launch
a command, and hook the stdin and stdout of that command up as a
transport for a protocol created by the given factory.
@param protocolFactory: A L{Factory} to use to create the protocol
which will be connected to the stdin and stdout of the command on
the SSH server.
@return: A L{Deferred} which will fire with an error if the connection
cannot be set up for any reason or with the protocol instance
created by C{protocolFactory} once it has been connected to the
command.
"""
d = self._creator.secureConnection()
d.addCallback(self._executeCommand, protocolFactory)
return d
def _executeCommand(self, connection, protocolFactory):
"""
Given a secured SSH connection, try to execute a command in a new
channel created on it and associate the result with a protocol from the
given factory.
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
C{protocolFactory} parameter.
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
"""
commandConnected = Deferred()
def disconnectOnFailure(passthrough):
# Close the connection immediately in case of cancellation, since
# that implies user wants it gone immediately (e.g. a timeout):
immediate = passthrough.check(CancelledError)
self._creator.cleanupConnection(connection, immediate)
return passthrough
commandConnected.addErrback(disconnectOnFailure)
channel = _CommandChannel(
self._creator, self._command, protocolFactory, commandConnected)
connection.openChannel(channel)
return commandConnected
class _ReadFile(object):
"""
A weakly file-like object which can be used with L{KnownHostsFile} to
respond in the negative to all prompts for decisions.
"""
def __init__(self, contents):
"""
@param contents: L{bytes} which will be returned from every C{readline}
call.
"""
self._contents = contents
def write(self, data):
"""
No-op.
@param data: ignored
"""
def readline(self, count=-1):
"""
Always give back the byte string that this L{_ReadFile} was initialized
with.
@param count: ignored
@return: A fixed byte-string.
@rtype: L{bytes}
"""
return self._contents
def close(self):
"""
No-op.
"""
@implementer(_ISSHConnectionCreator)
class _NewConnectionHelper(object):
"""
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
establishing a brand new SSH connection, securing it, and authenticating.
"""
_KNOWN_HOSTS = _KNOWN_HOSTS
port = 22
def __init__(self, reactor, hostname, port, command, username, keys,
password, agentEndpoint, knownHosts, ui,
tty=FilePath(b"/dev/tty")):
"""
@param tty: The path of the tty device to use in case C{ui} is C{None}.
@type tty: L{FilePath}
@see: L{SSHCommandClientEndpoint.newConnection}
"""
self.reactor = reactor
self.hostname = hostname
if port is not None:
self.port = port
self.command = command
self.username = username
self.keys = keys
self.password = password
self.agentEndpoint = agentEndpoint
if knownHosts is None:
knownHosts = self._knownHosts()
self.knownHosts = knownHosts
if ui is None:
ui = ConsoleUI(self._opener)
self.ui = ui
self.tty = tty
def _opener(self):
"""
Open the tty if possible, otherwise give back a file-like object from
which C{b"no"} can be read.
For use as the opener argument to L{ConsoleUI}.
"""
try:
return self.tty.open("r+")
except:
# Give back a file-like object from which can be read a byte string
# that KnownHostsFile recognizes as rejecting some option (b"no").
return _ReadFile(b"no")
@classmethod
def _knownHosts(cls):
"""
@return: A L{KnownHostsFile} instance pointed at the user's personal
I{known hosts} file.
@type: L{KnownHostsFile}
"""
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
def secureConnection(self):
"""
Create and return a new SSH connection which has been secured and on
which authentication has already happened.
@return: A L{Deferred} which fires with the ready-to-use connection or
with a failure if something prevents the connection from being
setup, secured, or authenticated.
"""
protocol = _CommandTransport(self)
ready = protocol.connectionReady
sshClient = TCP4ClientEndpoint(self.reactor, self.hostname, self.port)
d = connectProtocol(sshClient, protocol)
d.addCallback(lambda ignored: ready)
return d
def cleanupConnection(self, connection, immediate):
"""
Clean up the connection by closing it. The command running on the
endpoint has ended so the connection is no longer needed.
@param connection: The L{SSHConnection} to close.
@type connection: L{SSHConnection}
@param immediate: Whether to close connection immediately.
@type immediate: L{bool}.
"""
if immediate:
# We're assuming the underlying connection is a ITCPTransport,
# which is what the current implementation is restricted to:
connection.transport.transport.abortConnection()
else:
connection.transport.loseConnection()
@implementer(_ISSHConnectionCreator)
class _ExistingConnectionHelper(object):
"""
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
handing out an existing SSH connection which is supplied to its
initializer.
"""
def __init__(self, connection):
"""
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
"""
self.connection = connection
def secureConnection(self):
"""
@return: A L{Deferred} that fires synchronously with the
already-established connection object.
"""
return succeed(self.connection)
def cleanupConnection(self, connection, immediate):
"""
Do not do any cleanup on the connection. Leave that responsibility to
whatever code created it in the first place.
@param connection: The L{SSHConnection} which will not be modified in
any way.
@type connection: L{SSHConnection}
@param immediate: An argument which will be ignored.
@type immediate: L{bool}.
"""

View File

@ -0,0 +1,102 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An error to represent bad things happening in Conch.
Maintainer: Paul Swartz
"""
from twisted.cred.error import UnauthorizedLogin
class ConchError(Exception):
def __init__(self, value, data = None):
Exception.__init__(self, value, data)
self.value = value
self.data = data
class NotEnoughAuthentication(Exception):
"""
This is thrown if the authentication is valid, but is not enough to
successfully verify the user. i.e. don't retry this type of
authentication, try another one.
"""
class ValidPublicKey(UnauthorizedLogin):
"""
Raised by public key checkers when they receive public key credentials
that don't contain a signature at all, but are valid in every other way.
(e.g. the public key matches one in the user's authorized_keys file).
Protocol code (eg
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
attempts to log in using
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
should be prepared to handle a failure of this type by telling the user to
re-authenticate using the same key and to include a signature with the new
attempt.
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
"""
class IgnoreAuthentication(Exception):
"""
This is thrown to let the UserAuthServer know it doesn't need to handle the
authentication anymore.
"""
class MissingKeyStoreError(Exception):
"""
Raised if an SSHAgentServer starts receiving data without its factory
providing a keys dict on which to read/write key data.
"""
class UserRejectedKey(Exception):
"""
The user interactively rejected a key.
"""
class InvalidEntry(Exception):
"""
An entry in a known_hosts file could not be interpreted as a valid entry.
"""
class HostKeyChanged(Exception):
"""
The host key of a remote host has changed.
@ivar offendingEntry: The entry which contains the persistent host key that
disagrees with the given host key.
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
@ivar path: a reference to the known_hosts file that the offending entry
was loaded from
@type path: L{twisted.python.filepath.FilePath}
@ivar lineno: The line number of the offending entry in the given path.
@type lineno: L{int}
"""
def __init__(self, offendingEntry, path, lineno):
Exception.__init__(self)
self.offendingEntry = offendingEntry
self.path = path
self.lineno = lineno

View File

@ -0,0 +1,16 @@
"""
Insults: a replacement for Curses/S-Lang.
Very basic at the moment."""
from twisted.python import deprecate, versions
deprecate.deprecatedModuleAttribute(
versions.Version("Twisted", 10, 1, 0),
"Please use twisted.conch.insults.helper instead.",
__name__, "colors")
deprecate.deprecatedModuleAttribute(
versions.Version("Twisted", 10, 1, 0),
"Please use twisted.conch.insults.insults instead.",
__name__, "client")

View File

@ -0,0 +1,138 @@
"""
You don't really want to use this module. Try insults.py instead.
"""
from twisted.internet import protocol
class InsultsClient(protocol.Protocol):
escapeTimeout = 0.2
def __init__(self):
self.width = self.height = None
self.xpos = self.ypos = 0
self.commandQueue = []
self.inEscape = ''
def setSize(self, width, height):
call = 0
if self.width:
call = 1
self.width = width
self.height = height
if call:
self.windowSizeChanged()
def dataReceived(self, data):
from twisted.internet import reactor
for ch in data:
if ch == '\x1b':
if self.inEscape:
self.keyReceived(ch)
self.inEscape = ''
else:
self.inEscape = ch
self.escapeCall = reactor.callLater(self.escapeTimeout,
self.endEscape)
elif ch in 'ABCD' and self.inEscape:
self.inEscape = ''
self.escapeCall.cancel()
if ch == 'A':
self.keyReceived('<Up>')
elif ch == 'B':
self.keyReceived('<Down>')
elif ch == 'C':
self.keyReceived('<Right>')
elif ch == 'D':
self.keyReceived('<Left>')
elif self.inEscape:
self.inEscape += ch
else:
self.keyReceived(ch)
def endEscape(self):
ch = self.inEscape
self.inEscape = ''
self.keyReceived(ch)
def initScreen(self):
self.transport.write('\x1b=\x1b[?1h')
def gotoXY(self, x, y):
"""Go to a position on the screen.
"""
self.xpos = x
self.ypos = y
self.commandQueue.append(('gotoxy', x, y))
def writeCh(self, ch):
"""Write a character to the screen. If we're at the end of the row,
ignore the write.
"""
if self.xpos < self.width - 1:
self.commandQueue.append(('write', ch))
self.xpos += 1
def writeStr(self, s):
"""Write a string to the screen. This does not wrap a the edge of the
screen, and stops at \\r and \\n.
"""
s = s[:self.width-self.xpos]
if '\n' in s:
s=s[:s.find('\n')]
if '\r' in s:
s=s[:s.find('\r')]
self.commandQueue.append(('write', s))
self.xpos += len(s)
def eraseToLine(self):
"""Erase from the current position to the end of the line.
"""
self.commandQueue.append(('eraseeol',))
def eraseToScreen(self):
"""Erase from the current position to the end of the screen.
"""
self.commandQueue.append(('eraseeos',))
def clearScreen(self):
"""Clear the screen, and return the cursor to 0, 0.
"""
self.commandQueue = [('cls',)]
self.xpos = self.ypos = 0
def setAttributes(self, *attrs):
"""Set the attributes for drawing on the screen.
"""
self.commandQueue.append(('attributes', attrs))
def refresh(self):
"""Redraw the screen.
"""
redraw = ''
for command in self.commandQueue:
if command[0] == 'gotoxy':
redraw += '\x1b[%i;%iH' % (command[2]+1, command[1]+1)
elif command[0] == 'write':
redraw += command[1]
elif command[0] == 'eraseeol':
redraw += '\x1b[0K'
elif command[0] == 'eraseeos':
redraw += '\x1b[OJ'
elif command[0] == 'cls':
redraw += '\x1b[H\x1b[J'
elif command[0] == 'attributes':
redraw += '\x1b[%sm' % ';'.join(map(str, command[1]))
else:
print command
self.commandQueue = []
self.transport.write(redraw)
def windowSizeChanged(self):
"""Called when the size of the window changes.
Might want to redraw the screen here, or something.
"""
def keyReceived(self, key):
"""Called when the user hits a key.
"""

View File

@ -0,0 +1,29 @@
"""
You don't really want to use this module. Try helper.py instead.
"""
CLEAR = 0
BOLD = 1
DIM = 2
ITALIC = 3
UNDERSCORE = 4
BLINK_SLOW = 5
BLINK_FAST = 6
REVERSE = 7
CONCEALED = 8
FG_BLACK = 30
FG_RED = 31
FG_GREEN = 32
FG_YELLOW = 33
FG_BLUE = 34
FG_MAGENTA = 35
FG_CYAN = 36
FG_WHITE = 37
BG_BLACK = 40
BG_RED = 41
BG_GREEN = 42
BG_YELLOW = 43
BG_BLUE = 44
BG_MAGENTA = 45
BG_CYAN = 46
BG_WHITE = 47

View File

@ -0,0 +1,462 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Partial in-memory terminal emulator
@author: Jp Calderone
"""
import re, string
from zope.interface import implementer
from twisted.internet import defer, protocol, reactor
from twisted.python import log, _textattributes
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
from twisted.python.versions import Version
from twisted.conch.insults import insults
FOREGROUND = 30
BACKGROUND = 40
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
class _FormattingState(_textattributes._FormattingStateMixin):
"""
Represents the formatting state/attributes of a single character.
Character set, intensity, underlinedness, blinkitude, video
reversal, as well as foreground and background colors made up a
character's attributes.
"""
compareAttributes = (
'charset', 'bold', 'underline', 'blink', 'reverseVideo', 'foreground',
'background', '_subtracting')
def __init__(self, charset=insults.G0, bold=False, underline=False,
blink=False, reverseVideo=False, foreground=WHITE,
background=BLACK, _subtracting=False):
self.charset = charset
self.bold = bold
self.underline = underline
self.blink = blink
self.reverseVideo = reverseVideo
self.foreground = foreground
self.background = background
self._subtracting = _subtracting
@deprecated(Version('Twisted', 13, 1, 0))
def wantOne(self, **kw):
"""
Add a character attribute to a copy of this formatting state.
@param **kw: An optional attribute name and value can be provided with
a keyword argument.
@return: A formatting state instance with the new attribute.
@see: L{DefaultFormattingState._withAttribute}.
"""
k, v = kw.popitem()
return self._withAttribute(k, v)
def toVT102(self):
# Spit out a vt102 control sequence that will set up
# all the attributes set here. Except charset.
attrs = []
if self._subtracting:
attrs.append(0)
if self.bold:
attrs.append(insults.BOLD)
if self.underline:
attrs.append(insults.UNDERLINE)
if self.blink:
attrs.append(insults.BLINK)
if self.reverseVideo:
attrs.append(insults.REVERSE_VIDEO)
if self.foreground != WHITE:
attrs.append(FOREGROUND + self.foreground)
if self.background != BLACK:
attrs.append(BACKGROUND + self.background)
if attrs:
return '\x1b[' + ';'.join(map(str, attrs)) + 'm'
return ''
CharacterAttribute = _FormattingState
deprecatedModuleAttribute(
Version('Twisted', 13, 1, 0),
'Use twisted.conch.insults.text.assembleFormattedText instead.',
'twisted.conch.insults.helper',
'CharacterAttribute')
# XXX - need to support scroll regions and scroll history
@implementer(insults.ITerminalTransport)
class TerminalBuffer(protocol.Protocol):
"""
An in-memory terminal emulator.
"""
for keyID in ('UP_ARROW', 'DOWN_ARROW', 'RIGHT_ARROW', 'LEFT_ARROW',
'HOME', 'INSERT', 'DELETE', 'END', 'PGUP', 'PGDN',
'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9',
'F10', 'F11', 'F12'):
exec '%s = object()' % (keyID,)
TAB = '\t'
BACKSPACE = '\x7f'
width = 80
height = 24
fill = ' '
void = object()
def getCharacter(self, x, y):
return self.lines[y][x]
def connectionMade(self):
self.reset()
def write(self, bytes):
"""
Add the given printable bytes to the terminal.
Line feeds in C{bytes} will be replaced with carriage return / line
feed pairs.
"""
for b in bytes.replace('\n', '\r\n'):
self.insertAtCursor(b)
def _currentFormattingState(self):
return _FormattingState(self.activeCharset, **self.graphicRendition)
def insertAtCursor(self, b):
"""
Add one byte to the terminal at the cursor and make consequent state
updates.
If b is a carriage return, move the cursor to the beginning of the
current row.
If b is a line feed, move the cursor to the next row or scroll down if
the cursor is already in the last row.
Otherwise, if b is printable, put it at the cursor position (inserting
or overwriting as dictated by the current mode) and move the cursor.
"""
if b == '\r':
self.x = 0
elif b == '\n':
self._scrollDown()
elif b in string.printable:
if self.x >= self.width:
self.nextLine()
ch = (b, self._currentFormattingState())
if self.modes.get(insults.modes.IRM):
self.lines[self.y][self.x:self.x] = [ch]
self.lines[self.y].pop()
else:
self.lines[self.y][self.x] = ch
self.x += 1
def _emptyLine(self, width):
return [(self.void, self._currentFormattingState())
for i in xrange(width)]
def _scrollDown(self):
self.y += 1
if self.y >= self.height:
self.y -= 1
del self.lines[0]
self.lines.append(self._emptyLine(self.width))
def _scrollUp(self):
self.y -= 1
if self.y < 0:
self.y = 0
del self.lines[-1]
self.lines.insert(0, self._emptyLine(self.width))
def cursorUp(self, n=1):
self.y = max(0, self.y - n)
def cursorDown(self, n=1):
self.y = min(self.height - 1, self.y + n)
def cursorBackward(self, n=1):
self.x = max(0, self.x - n)
def cursorForward(self, n=1):
self.x = min(self.width, self.x + n)
def cursorPosition(self, column, line):
self.x = column
self.y = line
def cursorHome(self):
self.x = self.home.x
self.y = self.home.y
def index(self):
self._scrollDown()
def reverseIndex(self):
self._scrollUp()
def nextLine(self):
"""
Update the cursor position attributes and scroll down if appropriate.
"""
self.x = 0
self._scrollDown()
def saveCursor(self):
self._savedCursor = (self.x, self.y)
def restoreCursor(self):
self.x, self.y = self._savedCursor
del self._savedCursor
def setModes(self, modes):
for m in modes:
self.modes[m] = True
def resetModes(self, modes):
for m in modes:
try:
del self.modes[m]
except KeyError:
pass
def setPrivateModes(self, modes):
"""
Enable the given modes.
Track which modes have been enabled so that the implementations of
other L{insults.ITerminalTransport} methods can be properly implemented
to respect these settings.
@see: L{resetPrivateModes}
@see: L{insults.ITerminalTransport.setPrivateModes}
"""
for m in modes:
self.privateModes[m] = True
def resetPrivateModes(self, modes):
"""
Disable the given modes.
@see: L{setPrivateModes}
@see: L{insults.ITerminalTransport.resetPrivateModes}
"""
for m in modes:
try:
del self.privateModes[m]
except KeyError:
pass
def applicationKeypadMode(self):
self.keypadMode = 'app'
def numericKeypadMode(self):
self.keypadMode = 'num'
def selectCharacterSet(self, charSet, which):
self.charsets[which] = charSet
def shiftIn(self):
self.activeCharset = insults.G0
def shiftOut(self):
self.activeCharset = insults.G1
def singleShift2(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G2
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def singleShift3(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G3
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def selectGraphicRendition(self, *attributes):
for a in attributes:
if a == insults.NORMAL:
self.graphicRendition = {
'bold': False,
'underline': False,
'blink': False,
'reverseVideo': False,
'foreground': WHITE,
'background': BLACK}
elif a == insults.BOLD:
self.graphicRendition['bold'] = True
elif a == insults.UNDERLINE:
self.graphicRendition['underline'] = True
elif a == insults.BLINK:
self.graphicRendition['blink'] = True
elif a == insults.REVERSE_VIDEO:
self.graphicRendition['reverseVideo'] = True
else:
try:
v = int(a)
except ValueError:
log.msg("Unknown graphic rendition attribute: " + repr(a))
else:
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
self.graphicRendition['foreground'] = v - FOREGROUND
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
self.graphicRendition['background'] = v - BACKGROUND
else:
log.msg("Unknown graphic rendition attribute: " + repr(a))
def eraseLine(self):
self.lines[self.y] = self._emptyLine(self.width)
def eraseToLineEnd(self):
width = self.width - self.x
self.lines[self.y][self.x:] = self._emptyLine(width)
def eraseToLineBeginning(self):
self.lines[self.y][:self.x + 1] = self._emptyLine(self.x + 1)
def eraseDisplay(self):
self.lines = [self._emptyLine(self.width) for i in xrange(self.height)]
def eraseToDisplayEnd(self):
self.eraseToLineEnd()
height = self.height - self.y - 1
self.lines[self.y + 1:] = [self._emptyLine(self.width) for i in range(height)]
def eraseToDisplayBeginning(self):
self.eraseToLineBeginning()
self.lines[:self.y] = [self._emptyLine(self.width) for i in range(self.y)]
def deleteCharacter(self, n=1):
del self.lines[self.y][self.x:self.x+n]
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
def insertLine(self, n=1):
self.lines[self.y:self.y] = [self._emptyLine(self.width) for i in range(n)]
del self.lines[self.height:]
def deleteLine(self, n=1):
del self.lines[self.y:self.y+n]
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
def reportCursorPosition(self):
return (self.x, self.y)
def reset(self):
self.home = insults.Vector(0, 0)
self.x = self.y = 0
self.modes = {}
self.privateModes = {}
self.setPrivateModes([insults.privateModes.AUTO_WRAP,
insults.privateModes.CURSOR_MODE])
self.numericKeypad = 'app'
self.activeCharset = insults.G0
self.graphicRendition = {
'bold': False,
'underline': False,
'blink': False,
'reverseVideo': False,
'foreground': WHITE,
'background': BLACK}
self.charsets = {
insults.G0: insults.CS_US,
insults.G1: insults.CS_US,
insults.G2: insults.CS_ALTERNATE,
insults.G3: insults.CS_ALTERNATE_SPECIAL}
self.eraseDisplay()
def unhandledControlSequence(self, buf):
print 'Could not handle', repr(buf)
def __str__(self):
lines = []
for L in self.lines:
buf = []
length = 0
for (ch, attr) in L:
if ch is not self.void:
buf.append(ch)
length = len(buf)
else:
buf.append(self.fill)
lines.append(''.join(buf[:length]))
return '\n'.join(lines)
class ExpectationTimeout(Exception):
pass
class ExpectableBuffer(TerminalBuffer):
_mark = 0
def connectionMade(self):
TerminalBuffer.connectionMade(self)
self._expecting = []
def write(self, bytes):
TerminalBuffer.write(self, bytes)
self._checkExpected()
def cursorHome(self):
TerminalBuffer.cursorHome(self)
self._mark = 0
def _timeoutExpected(self, d):
d.errback(ExpectationTimeout())
self._checkExpected()
def _checkExpected(self):
s = str(self)[self._mark:]
while self._expecting:
expr, timer, deferred = self._expecting[0]
if timer and not timer.active():
del self._expecting[0]
continue
for match in expr.finditer(s):
if timer:
timer.cancel()
del self._expecting[0]
self._mark += match.end()
s = s[match.end():]
deferred.callback(match)
break
else:
return
def expect(self, expression, timeout=None, scheduler=reactor):
d = defer.Deferred()
timer = None
if timeout:
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
self._expecting.append((re.compile(expression), timer, d))
self._checkExpected()
return d
__all__ = [
'CharacterAttribute', 'TerminalBuffer', 'ExpectableBuffer']

View File

@ -0,0 +1,175 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Character attribute manipulation API.
This module provides a domain-specific language (using Python syntax)
for the creation of text with additional display attributes associated
with it. It is intended as an alternative to manually building up
strings containing ECMA 48 character attribute control codes. It
currently supports foreground and background colors (black, red,
green, yellow, blue, magenta, cyan, and white), intensity selection,
underlining, blinking and reverse video. Character set selection
support is planned.
Character attributes are specified by using two Python operations:
attribute lookup and indexing. For example, the string \"Hello
world\" with red foreground and all other attributes set to their
defaults, assuming the name twisted.conch.insults.text.attributes has
been imported and bound to the name \"A\" (with the statement C{from
twisted.conch.insults.text import attributes as A}, for example) one
uses this expression::
A.fg.red[\"Hello world\"]
Other foreground colors are set by substituting their name for
\"red\". To set both a foreground and a background color, this
expression is used::
A.fg.red[A.bg.green[\"Hello world\"]]
Note that either A.bg.green can be nested within A.fg.red or vice
versa. Also note that multiple items can be nested within a single
index operation by separating them with commas::
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
Other character attributes are set in a similar fashion. To specify a
blinking version of the previous expression::
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
A third operation is actually supported: unary negation. This turns
off an attribute when an enclosing expression would otherwise have
caused it to be on. For example::
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
A formatting structure can then be serialized into a string containing the
necessary VT102 control codes with L{assembleFormattedText}.
@see: L{twisted.conch.insults.text.attributes}
@author: Jp Calderone
"""
from twisted.conch.insults import helper, insults
from twisted.python import _textattributes
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.versions import Version
flatten = _textattributes.flatten
deprecatedModuleAttribute(
Version('Twisted', 13, 1, 0),
'Use twisted.conch.insults.text.assembleFormattedText instead.',
'twisted.conch.insults.text',
'flatten')
_TEXT_COLORS = {
'black': helper.BLACK,
'red': helper.RED,
'green': helper.GREEN,
'yellow': helper.YELLOW,
'blue': helper.BLUE,
'magenta': helper.MAGENTA,
'cyan': helper.CYAN,
'white': helper.WHITE}
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
"""
Factory for character attributes, including foreground and background color
and non-color attributes such as bold, reverse video and underline.
Character attributes are applied to actual text by using object
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
example::
attributes.bold['Some text']
These can be nested to mix attributes::
attributes.bold[attributes.underline['Some text']]
And multiple values can be passed::
attributes.normal[attributes.bold['Some'], ' text']
Non-color attributes can be accessed by attribute name, available
attributes are:
- bold
- blink
- reverseVideo
- underline
Available colors are:
0. black
1. red
2. green
3. yellow
4. blue
5. magenta
6. cyan
7. white
@ivar fg: Foreground colors accessed by attribute name, see above
for possible names.
@ivar bg: Background colors accessed by attribute name, see above
for possible names.
"""
fg = _textattributes._ColorAttribute(
_textattributes._ForegroundColorAttr, _TEXT_COLORS)
bg = _textattributes._ColorAttribute(
_textattributes._BackgroundColorAttr, _TEXT_COLORS)
attrs = {
'bold': insults.BOLD,
'blink': insults.BLINK,
'underline': insults.UNDERLINE,
'reverseVideo': insults.REVERSE_VIDEO}
def assembleFormattedText(formatted):
"""
Assemble formatted text from structured information.
Currently handled formatting includes: bold, blink, reverse, underline and
color codes.
For example::
from twisted.conch.insults.text import attributes as A
assembleFormattedText(
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
Would produce "Time: " in bold formatting, followed by "Now!" with a
foreground color of light red and without any additional formatting.
@param formatted: Structured text and attributes.
@rtype: C{str}
@return: String containing VT102 control sequences that mimic those
specified by L{formatted}.
@see: L{twisted.conch.insults.text.attributes}
@since: 13.1
"""
return _textattributes.flatten(
formatted, helper._FormattingState(), 'toVT102')
attributes = _CharacterAttributes()
__all__ = ['attributes', 'flatten']

View File

@ -0,0 +1,868 @@
# -*- test-case-name: twisted.conch.test.test_window -*-
"""
Simple insults-based widget library
@author: Jp Calderone
"""
import array
from twisted.conch.insults import insults, helper
from twisted.python import text as tptext
class YieldFocus(Exception):
"""Input focus manipulation exception
"""
class BoundedTerminalWrapper(object):
def __init__(self, terminal, width, height, xoff, yoff):
self.width = width
self.height = height
self.xoff = xoff
self.yoff = yoff
self.terminal = terminal
self.cursorForward = terminal.cursorForward
self.selectCharacterSet = terminal.selectCharacterSet
self.selectGraphicRendition = terminal.selectGraphicRendition
self.saveCursor = terminal.saveCursor
self.restoreCursor = terminal.restoreCursor
def cursorPosition(self, x, y):
return self.terminal.cursorPosition(
self.xoff + min(self.width, x),
self.yoff + min(self.height, y)
)
def cursorHome(self):
return self.terminal.cursorPosition(
self.xoff, self.yoff)
def write(self, bytes):
return self.terminal.write(bytes)
class Widget(object):
focused = False
parent = None
dirty = False
width = height = None
def repaint(self):
if not self.dirty:
self.dirty = True
if self.parent is not None and not self.parent.dirty:
self.parent.repaint()
def filthy(self):
self.dirty = True
def redraw(self, width, height, terminal):
self.filthy()
self.draw(width, height, terminal)
def draw(self, width, height, terminal):
if width != self.width or height != self.height or self.dirty:
self.width = width
self.height = height
self.dirty = False
self.render(width, height, terminal)
def render(self, width, height, terminal):
pass
def sizeHint(self):
return None
def keystrokeReceived(self, keyID, modifier):
if keyID == '\t':
self.tabReceived(modifier)
elif keyID == '\x7f':
self.backspaceReceived()
elif keyID in insults.FUNCTION_KEYS:
self.functionKeyReceived(keyID, modifier)
else:
self.characterReceived(keyID, modifier)
def tabReceived(self, modifier):
# XXX TODO - Handle shift+tab
raise YieldFocus()
def focusReceived(self):
"""Called when focus is being given to this widget.
May raise YieldFocus is this widget does not want focus.
"""
self.focused = True
self.repaint()
def focusLost(self):
self.focused = False
self.repaint()
def backspaceReceived(self):
pass
def functionKeyReceived(self, keyID, modifier):
func = getattr(self, 'func_' + keyID.name, None)
if func is not None:
func(modifier)
def characterReceived(self, keyID, modifier):
pass
class ContainerWidget(Widget):
"""
@ivar focusedChild: The contained widget which currently has
focus, or None.
"""
focusedChild = None
focused = False
def __init__(self):
Widget.__init__(self)
self.children = []
def addChild(self, child):
assert child.parent is None
child.parent = self
self.children.append(child)
if self.focusedChild is None and self.focused:
try:
child.focusReceived()
except YieldFocus:
pass
else:
self.focusedChild = child
self.repaint()
def remChild(self, child):
assert child.parent is self
child.parent = None
self.children.remove(child)
self.repaint()
def filthy(self):
for ch in self.children:
ch.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
for ch in self.children:
ch.draw(width, height, terminal)
def changeFocus(self):
self.repaint()
if self.focusedChild is not None:
self.focusedChild.focusLost()
focusedChild = self.focusedChild
self.focusedChild = None
try:
curFocus = self.children.index(focusedChild) + 1
except ValueError:
raise YieldFocus()
else:
curFocus = 0
while curFocus < len(self.children):
try:
self.children[curFocus].focusReceived()
except YieldFocus:
curFocus += 1
else:
self.focusedChild = self.children[curFocus]
return
# None of our children wanted focus
raise YieldFocus()
def focusReceived(self):
self.changeFocus()
self.focused = True
def keystrokeReceived(self, keyID, modifier):
if self.focusedChild is not None:
try:
self.focusedChild.keystrokeReceived(keyID, modifier)
except YieldFocus:
self.changeFocus()
self.repaint()
else:
Widget.keystrokeReceived(self, keyID, modifier)
class TopWindow(ContainerWidget):
"""
A top-level container object which provides focus wrap-around and paint
scheduling.
@ivar painter: A no-argument callable which will be invoked when this
widget needs to be redrawn.
@ivar scheduler: A one-argument callable which will be invoked with a
no-argument callable and should arrange for it to invoked at some point in
the near future. The no-argument callable will cause this widget and all
its children to be redrawn. It is typically beneficial for the no-argument
callable to be invoked at the end of handling for whatever event is
currently active; for example, it might make sense to call it at the end of
L{twisted.conch.insults.insults.ITerminalProtocol.keystrokeReceived}.
Note, however, that since calls to this may also be made in response to no
apparent event, arrangements should be made for the function to be called
even if an event handler such as C{keystrokeReceived} is not on the call
stack (eg, using C{reactor.callLater} with a short timeout).
"""
focused = True
def __init__(self, painter, scheduler):
ContainerWidget.__init__(self)
self.painter = painter
self.scheduler = scheduler
_paintCall = None
def repaint(self):
if self._paintCall is None:
self._paintCall = object()
self.scheduler(self._paint)
ContainerWidget.repaint(self)
def _paint(self):
self._paintCall = None
self.painter()
def changeFocus(self):
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
try:
ContainerWidget.changeFocus(self)
except YieldFocus:
pass
def keystrokeReceived(self, keyID, modifier):
try:
ContainerWidget.keystrokeReceived(self, keyID, modifier)
except YieldFocus:
self.changeFocus()
class AbsoluteBox(ContainerWidget):
def moveChild(self, child, x, y):
for n in range(len(self.children)):
if self.children[n][0] is child:
self.children[n] = (child, x, y)
break
else:
raise ValueError("No such child", child)
def render(self, width, height, terminal):
for (ch, x, y) in self.children:
wrap = BoundedTerminalWrapper(terminal, width - x, height - y, x, y)
ch.draw(width, height, wrap)
class _Box(ContainerWidget):
TOP, CENTER, BOTTOM = range(3)
def __init__(self, gravity=CENTER):
ContainerWidget.__init__(self)
self.gravity = gravity
def sizeHint(self):
height = 0
width = 0
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if self.variableDimension == 0:
if hint[0] is None:
width = None
elif width is not None:
width += hint[0]
if hint[1] is None:
height = None
elif height is not None:
height = max(height, hint[1])
else:
if hint[0] is None:
width = None
elif width is not None:
width = max(width, hint[0])
if hint[1] is None:
height = None
elif height is not None:
height += hint[1]
return width, height
def render(self, width, height, terminal):
if not self.children:
return
greedy = 0
wants = []
for ch in self.children:
hint = ch.sizeHint()
if hint is None:
hint = (None, None)
if hint[self.variableDimension] is None:
greedy += 1
wants.append(hint[self.variableDimension])
length = (width, height)[self.variableDimension]
totalWant = sum([w for w in wants if w is not None])
if greedy:
leftForGreedy = int((length - totalWant) / greedy)
widthOffset = heightOffset = 0
for want, ch in zip(wants, self.children):
if want is None:
want = leftForGreedy
subWidth, subHeight = width, height
if self.variableDimension == 0:
subWidth = want
else:
subHeight = want
wrap = BoundedTerminalWrapper(
terminal,
subWidth,
subHeight,
widthOffset,
heightOffset,
)
ch.draw(subWidth, subHeight, wrap)
if self.variableDimension == 0:
widthOffset += want
else:
heightOffset += want
class HBox(_Box):
variableDimension = 0
class VBox(_Box):
variableDimension = 1
class Packer(ContainerWidget):
def render(self, width, height, terminal):
if not self.children:
return
root = int(len(self.children) ** 0.5 + 0.5)
boxes = [VBox() for n in range(root)]
for n, ch in enumerate(self.children):
boxes[n % len(boxes)].addChild(ch)
h = HBox()
map(h.addChild, boxes)
h.render(width, height, terminal)
class Canvas(Widget):
focused = False
contents = None
def __init__(self):
Widget.__init__(self)
self.resize(1, 1)
def resize(self, width, height):
contents = array.array('c', ' ' * width * height)
if self.contents is not None:
for x in range(min(width, self._width)):
for y in range(min(height, self._height)):
contents[width * y + x] = self[x, y]
self.contents = contents
self._width = width
self._height = height
if self.x >= width:
self.x = width - 1
if self.y >= height:
self.y = height - 1
def __getitem__(self, (x, y)):
return self.contents[(self._width * y) + x]
def __setitem__(self, (x, y), value):
self.contents[(self._width * y) + x] = value
def clear(self):
self.contents = array.array('c', ' ' * len(self.contents))
def render(self, width, height, terminal):
if not width or not height:
return
if width != self._width or height != self._height:
self.resize(width, height)
for i in range(height):
terminal.cursorPosition(0, i)
terminal.write(''.join(self.contents[self._width * i:self._width * i + self._width])[:width])
def horizontalLine(terminal, y, left, right):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(left, y)
terminal.write(chr(0161) * (right - left))
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def verticalLine(terminal, x, top, bottom):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
for n in xrange(top, bottom):
terminal.cursorPosition(x, n)
terminal.write(chr(0170))
terminal.selectCharacterSet(insults.CS_US, insults.G0)
def rectangle(terminal, (top, left), (width, height)):
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
terminal.cursorPosition(top, left)
terminal.write(chr(0154))
terminal.write(chr(0161) * (width - 2))
terminal.write(chr(0153))
for n in range(height - 2):
terminal.cursorPosition(left, top + n + 1)
terminal.write(chr(0170))
terminal.cursorForward(width - 2)
terminal.write(chr(0170))
terminal.cursorPosition(0, top + height - 1)
terminal.write(chr(0155))
terminal.write(chr(0161) * (width - 2))
terminal.write(chr(0152))
terminal.selectCharacterSet(insults.CS_US, insults.G0)
class Border(Widget):
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
def focusReceived(self):
return self.containee.focusReceived()
def focusLost(self):
return self.containee.focusLost()
def keystrokeReceived(self, keyID, modifier):
return self.containee.keystrokeReceived(keyID, modifier)
def sizeHint(self):
hint = self.containee.sizeHint()
if hint is None:
hint = (None, None)
if hint[0] is None:
x = None
else:
x = hint[0] + 2
if hint[1] is None:
y = None
else:
y = hint[1] + 2
return x, y
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
if self.containee.focused:
terminal.write('\x1b[31m')
rectangle(terminal, (0, 0), (width, height))
terminal.write('\x1b[0m')
wrap = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self.containee.draw(width - 2, height - 2, wrap)
class Button(Widget):
def __init__(self, label, onPress):
Widget.__init__(self)
self.label = label
self.onPress = onPress
def sizeHint(self):
return len(self.label), 1
def characterReceived(self, keyID, modifier):
if keyID == '\r':
self.onPress()
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write('\x1b[1m' + self.label + '\x1b[0m')
else:
terminal.write(self.label)
class TextInput(Widget):
def __init__(self, maxwidth, onSubmit):
Widget.__init__(self)
self.onSubmit = onSubmit
self.maxwidth = maxwidth
self.buffer = ''
self.cursor = 0
def setText(self, text):
self.buffer = text[:self.maxwidth]
self.cursor = len(self.buffer)
self.repaint()
def func_LEFT_ARROW(self, modifier):
if self.cursor > 0:
self.cursor -= 1
self.repaint()
def func_RIGHT_ARROW(self, modifier):
if self.cursor < len(self.buffer):
self.cursor += 1
self.repaint()
def backspaceReceived(self):
if self.cursor > 0:
self.buffer = self.buffer[:self.cursor - 1] + self.buffer[self.cursor:]
self.cursor -= 1
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == '\r':
self.onSubmit(self.buffer)
else:
if len(self.buffer) < self.maxwidth:
self.buffer = self.buffer[:self.cursor] + keyID + self.buffer[self.cursor:]
self.cursor += 1
self.repaint()
def sizeHint(self):
return self.maxwidth + 1, 1
def render(self, width, height, terminal):
currentText = self._renderText()
terminal.cursorPosition(0, 0)
if self.focused:
terminal.write(currentText[:self.cursor])
cursor(terminal, currentText[self.cursor:self.cursor+1] or ' ')
terminal.write(currentText[self.cursor+1:])
terminal.write(' ' * (self.maxwidth - len(currentText) + 1))
else:
more = self.maxwidth - len(currentText)
terminal.write(currentText + '_' * more)
def _renderText(self):
return self.buffer
class PasswordInput(TextInput):
def _renderText(self):
return '*' * len(self.buffer)
class TextOutput(Widget):
text = ''
def __init__(self, size=None):
Widget.__init__(self)
self.size = size
def sizeHint(self):
return self.size
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
text = self.text[:width]
terminal.write(text + ' ' * (width - len(text)))
def setText(self, text):
self.text = text
self.repaint()
def focusReceived(self):
raise YieldFocus()
class TextOutputArea(TextOutput):
WRAP, TRUNCATE = range(2)
def __init__(self, size=None, longLines=WRAP):
TextOutput.__init__(self, size)
self.longLines = longLines
def render(self, width, height, terminal):
n = 0
inputLines = self.text.splitlines()
outputLines = []
while inputLines:
if self.longLines == self.WRAP:
wrappedLines = tptext.greedyWrap(inputLines.pop(0), width)
outputLines.extend(wrappedLines or [''])
else:
outputLines.append(inputLines.pop(0)[:width])
if len(outputLines) >= height:
break
for n, L in enumerate(outputLines[:height]):
terminal.cursorPosition(0, n)
terminal.write(L)
class Viewport(Widget):
_xOffset = 0
_yOffset = 0
def xOffset():
def get(self):
return self._xOffset
def set(self, value):
if self._xOffset != value:
self._xOffset = value
self.repaint()
return get, set
xOffset = property(*xOffset())
def yOffset():
def get(self):
return self._yOffset
def set(self, value):
if self._yOffset != value:
self._yOffset = value
self.repaint()
return get, set
yOffset = property(*yOffset())
_width = 160
_height = 24
def __init__(self, containee):
Widget.__init__(self)
self.containee = containee
self.containee.parent = self
self._buf = helper.TerminalBuffer()
self._buf.width = self._width
self._buf.height = self._height
self._buf.connectionMade()
def filthy(self):
self.containee.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
self.containee.draw(self._width, self._height, self._buf)
# XXX /Lame/
for y, line in enumerate(self._buf.lines[self._yOffset:self._yOffset + height]):
terminal.cursorPosition(0, y)
n = 0
for n, (ch, attr) in enumerate(line[self._xOffset:self._xOffset + width]):
if ch is self._buf.void:
ch = ' '
terminal.write(ch)
if n < width:
terminal.write(' ' * (width - n - 1))
class _Scrollbar(Widget):
def __init__(self, onScroll):
Widget.__init__(self)
self.onScroll = onScroll
self.percent = 0.0
def smaller(self):
self.percent = min(1.0, max(0.0, self.onScroll(-1)))
self.repaint()
def bigger(self):
self.percent = min(1.0, max(0.0, self.onScroll(+1)))
self.repaint()
class HorizontalScrollbar(_Scrollbar):
def sizeHint(self):
return (None, 1)
def func_LEFT_ARROW(self, modifier):
self.smaller()
def func_RIGHT_ARROW(self, modifier):
self.bigger()
_left = u'\N{BLACK LEFT-POINTING TRIANGLE}'
_right = u'\N{BLACK RIGHT-POINTING TRIANGLE}'
_bar = u'\N{LIGHT SHADE}'
_slider = u'\N{DARK SHADE}'
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
n = width - 3
before = int(n * self.percent)
after = n - before
me = self._left + (self._bar * before) + self._slider + (self._bar * after) + self._right
terminal.write(me.encode('utf-8'))
class VerticalScrollbar(_Scrollbar):
def sizeHint(self):
return (1, None)
def func_UP_ARROW(self, modifier):
self.smaller()
def func_DOWN_ARROW(self, modifier):
self.bigger()
_up = u'\N{BLACK UP-POINTING TRIANGLE}'
_down = u'\N{BLACK DOWN-POINTING TRIANGLE}'
_bar = u'\N{LIGHT SHADE}'
_slider = u'\N{DARK SHADE}'
def render(self, width, height, terminal):
terminal.cursorPosition(0, 0)
knob = int(self.percent * (height - 2))
terminal.write(self._up.encode('utf-8'))
for i in xrange(1, height - 1):
terminal.cursorPosition(0, i)
if i != (knob + 1):
terminal.write(self._bar.encode('utf-8'))
else:
terminal.write(self._slider.encode('utf-8'))
terminal.cursorPosition(0, height - 1)
terminal.write(self._down.encode('utf-8'))
class ScrolledArea(Widget):
"""
A L{ScrolledArea} contains another widget wrapped in a viewport and
vertical and horizontal scrollbars for moving the viewport around.
"""
def __init__(self, containee):
Widget.__init__(self)
self._viewport = Viewport(containee)
self._horiz = HorizontalScrollbar(self._horizScroll)
self._vert = VerticalScrollbar(self._vertScroll)
for w in self._viewport, self._horiz, self._vert:
w.parent = self
def _horizScroll(self, n):
self._viewport.xOffset += n
self._viewport.xOffset = max(0, self._viewport.xOffset)
return self._viewport.xOffset / 25.0
def _vertScroll(self, n):
self._viewport.yOffset += n
self._viewport.yOffset = max(0, self._viewport.yOffset)
return self._viewport.yOffset / 25.0
def func_UP_ARROW(self, modifier):
self._vert.smaller()
def func_DOWN_ARROW(self, modifier):
self._vert.bigger()
def func_LEFT_ARROW(self, modifier):
self._horiz.smaller()
def func_RIGHT_ARROW(self, modifier):
self._horiz.bigger()
def filthy(self):
self._viewport.filthy()
self._horiz.filthy()
self._vert.filthy()
Widget.filthy(self)
def render(self, width, height, terminal):
wrapper = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
self._viewport.draw(width - 2, height - 2, wrapper)
if self.focused:
terminal.write('\x1b[31m')
horizontalLine(terminal, 0, 1, width - 1)
verticalLine(terminal, 0, 1, height - 1)
self._vert.draw(1, height - 1, BoundedTerminalWrapper(terminal, 1, height - 1, width - 1, 0))
self._horiz.draw(width, 1, BoundedTerminalWrapper(terminal, width, 1, 0, height - 1))
terminal.write('\x1b[0m')
def cursor(terminal, ch):
terminal.saveCursor()
terminal.selectGraphicRendition(str(insults.REVERSE_VIDEO))
terminal.write(ch)
terminal.restoreCursor()
terminal.cursorForward()
class Selection(Widget):
# Index into the sequence
focusedIndex = 0
# Offset into the displayed subset of the sequence
renderOffset = 0
def __init__(self, sequence, onSelect, minVisible=None):
Widget.__init__(self)
self.sequence = sequence
self.onSelect = onSelect
self.minVisible = minVisible
if minVisible is not None:
self._width = max(map(len, self.sequence))
def sizeHint(self):
if self.minVisible is not None:
return self._width, self.minVisible
def func_UP_ARROW(self, modifier):
if self.focusedIndex > 0:
self.focusedIndex -= 1
if self.renderOffset > 0:
self.renderOffset -= 1
self.repaint()
def func_PGUP(self, modifier):
if self.renderOffset != 0:
self.focusedIndex -= self.renderOffset
self.renderOffset = 0
else:
self.focusedIndex = max(0, self.focusedIndex - self.height)
self.repaint()
def func_DOWN_ARROW(self, modifier):
if self.focusedIndex < len(self.sequence) - 1:
self.focusedIndex += 1
if self.renderOffset < self.height - 1:
self.renderOffset += 1
self.repaint()
def func_PGDN(self, modifier):
if self.renderOffset != self.height - 1:
change = self.height - self.renderOffset - 1
if change + self.focusedIndex >= len(self.sequence):
change = len(self.sequence) - self.focusedIndex - 1
self.focusedIndex += change
self.renderOffset = self.height - 1
else:
self.focusedIndex = min(len(self.sequence) - 1, self.focusedIndex + self.height)
self.repaint()
def characterReceived(self, keyID, modifier):
if keyID == '\r':
self.onSelect(self.sequence[self.focusedIndex])
def render(self, width, height, terminal):
self.height = height
start = self.focusedIndex - self.renderOffset
if start > len(self.sequence) - height:
start = max(0, len(self.sequence) - height)
elements = self.sequence[start:start+height]
for n, ele in enumerate(elements):
terminal.cursorPosition(0, n)
if n == self.renderOffset:
terminal.saveCursor()
if self.focused:
modes = str(insults.REVERSE_VIDEO), str(insults.BOLD)
else:
modes = str(insults.REVERSE_VIDEO),
terminal.selectGraphicRendition(*modes)
text = ele[:width]
terminal.write(text + (' ' * (width - len(text))))
if n == self.renderOffset:
terminal.restoreCursor()

View File

@ -0,0 +1,408 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains interfaces defined for the L{twisted.conch} package.
"""
from zope.interface import Interface, Attribute
class IConchUser(Interface):
"""
A user who has been authenticated to Cred through Conch. This is
the interface between the SSH connection and the user.
"""
conn = Attribute('The SSHConnection object for this user.')
def lookupChannel(channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
channelType is the type of channel being requested,
windowSize is the initial size of the remote window,
maxPacket is the largest packet we should send,
data is any other packet data (often nothing).
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
an appropriate channel can not be found, an exception will be
raised. If a L{ConchError<error.ConchError>} is raised, the .value
will be the message, and the .data will be the error code.
@type channelType: C{str}
@type windowSize: C{int}
@type maxPacket: C{int}
@type data: C{str}
@rtype: subclass of L{SSHChannel}/C{tuple}
"""
def lookupSubsystem(subsystem, data):
"""
The other side requested a subsystem.
subsystem is the name of the subsystem being requested.
data is any other packet data (often nothing).
We return a L{Protocol}.
"""
def gotGlobalRequest(requestType, data):
"""
A global request was sent from the other side.
By default, this dispatches to a method 'channel_channelType' with any
non-alphanumerics in the channelType replace with _'s. If it cannot
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
The method is called with arguments of windowSize, maxPacket, data.
"""
class ISession(Interface):
def getPty(term, windowSize, modes):
"""
Get a pseudo-terminal for use by a shell or command.
If a pseudo-terminal is not available, or the request otherwise
fails, raise an exception.
"""
def openShell(proto):
"""
Open a shell and connect it to proto.
@param proto: a L{ProcessProtocol} instance.
"""
def execCommand(proto, command):
"""
Execute a command.
@param proto: a L{ProcessProtocol} instance.
"""
def windowChanged(newWindowSize):
"""
Called when the size of the remote screen has changed.
"""
def eofReceived():
"""
Called when the other side has indicated no more data will be sent.
"""
def closed():
"""
Called when the session is closed.
"""
class ISFTPServer(Interface):
"""
SFTP subsystem for server-side communication.
Each method should check to verify that the user has permission for
their actions.
"""
avatar = Attribute(
"""
The avatar returned by the Realm that we are authenticated with,
and represents the logged-in user.
""")
def gotVersion(otherVersion, extData):
"""
Called when the client sends their version info.
otherVersion is an integer representing the version of the SFTP
protocol they are claiming.
extData is a dictionary of extended_name : extended_data items.
These items are sent by the client to indicate additional features.
This method should return a dictionary of extended_name : extended_data
items. These items are the additional features (if any) supported
by the server.
"""
return {}
def openFile(filename, flags, attrs):
"""
Called when the clients asks to open a file.
@param filename: a string representing the file to open.
@param flags: an integer of the flags to open the file with, ORed together.
The flags and their values are listed at the bottom of this file.
@param attrs: a list of attributes to open the file with. It is a
dictionary, consisting of 0 or more keys. The possible keys are::
size: the size of the file in bytes
uid: the user ID of the file as an integer
gid: the group ID of the file as an integer
permissions: the permissions of the file with as an integer.
the bit representation of this field is defined by POSIX.
atime: the access time of the file as seconds since the epoch.
mtime: the modification time of the file as seconds since the epoch.
ext_*: extended attributes. The server is not required to
understand this, but it may.
NOTE: there is no way to indicate text or binary files. it is up
to the SFTP client to deal with this.
This method returns an object that meets the ISFTPFile interface.
Alternatively, it can return a L{Deferred} that will be called back
with the object.
"""
def removeFile(filename):
"""
Remove the given file.
This method returns when the remove succeeds, or a Deferred that is
called back when it succeeds.
@param filename: the name of the file as a string.
"""
def renameFile(oldpath, newpath):
"""
Rename the given file.
This method returns when the rename succeeds, or a L{Deferred} that is
called back when it succeeds. If the rename fails, C{renameFile} will
raise an implementation-dependent exception.
@param oldpath: the current location of the file.
@param newpath: the new file name.
"""
def makeDirectory(path, attrs):
"""
Make a directory.
This method returns when the directory is created, or a Deferred that
is called back when it is created.
@param path: the name of the directory to create as a string.
@param attrs: a dictionary of attributes to create the directory with.
Its meaning is the same as the attrs in the L{openFile} method.
"""
def removeDirectory(path):
"""
Remove a directory (non-recursively)
It is an error to remove a directory that has files or directories in
it.
This method returns when the directory is removed, or a Deferred that
is called back when it is removed.
@param path: the directory to remove.
"""
def openDirectory(path):
"""
Open a directory for scanning.
This method returns an iterable object that has a close() method,
or a Deferred that is called back with same.
The close() method is called when the client is finished reading
from the directory. At this point, the iterable will no longer
be used.
The iterable should return triples of the form (filename,
longname, attrs) or Deferreds that return the same. The
sequence must support __getitem__, but otherwise may be any
'sequence-like' object.
filename is the name of the file relative to the directory.
logname is an expanded format of the filename. The recommended format
is:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
The first line is sample output, the second is the length of the field.
The fields are: permissions, link count, user owner, group owner,
size in bytes, modification time.
attrs is a dictionary in the format of the attrs argument to openFile.
@param path: the directory to open.
"""
def getAttrs(path, followLinks):
"""
Return the attributes for the given path.
This method returns a dictionary in the same format as the attrs
argument to openFile or a Deferred that is called back with same.
@param path: the path to return attributes for as a string.
@param followLinks: a boolean. If it is True, follow symbolic links
and return attributes for the real path at the base. If it is False,
return attributes for the specified path.
"""
def setAttrs(path, attrs):
"""
Set the attributes for the path.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param path: the path to set attributes for as a string.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""
def readLink(path):
"""
Find the root of a set of symbolic links.
This method returns the target of the link, or a Deferred that
returns the same.
@param path: the path of the symlink to read.
"""
def makeLink(linkPath, targetPath):
"""
Create a symbolic link.
This method returns when the link is made, or a Deferred that
returns the same.
@param linkPath: the pathname of the symlink as a string.
@param targetPath: the path of the target of the link as a string.
"""
def realPath(path):
"""
Convert any path to an absolute path.
This method returns the absolute path as a string, or a Deferred
that returns the same.
@param path: the path to convert as a string.
"""
def extendedRequest(extendedName, extendedData):
"""
This is the extension mechanism for SFTP. The other side can send us
arbitrary requests.
If we don't implement the request given by extendedName, raise
NotImplementedError.
The return value is a string, or a Deferred that will be called
back with a string.
@param extendedName: the name of the request as a string.
@param extendedData: the data the other side sent with the request,
as a string.
"""
class IKnownHostEntry(Interface):
"""
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
file.
@since: 8.2
"""
def matchesKey(key):
"""
Return True if this entry matches the given Key object, False
otherwise.
@param key: The key object to match against.
@type key: L{twisted.conch.ssh.Key}
"""
def matchesHost(hostname):
"""
Return True if this entry matches the given hostname, False otherwise.
Note that this does no name resolution; if you want to match an IP
address, you have to resolve it yourself, and pass it in as a dotted
quad string.
@param key: The hostname to match against.
@type key: L{str}
"""
def toString():
"""
@return: a serialized string representation of this entry, suitable for
inclusion in a known_hosts file. (Newline not included.)
@rtype: L{str}
"""
class ISFTPFile(Interface):
"""
This represents an open file on the server. An object adhering to this
interface should be returned from L{openFile}().
"""
def close():
"""
Close the file.
This method returns nothing if the close succeeds immediately, or a
Deferred that is called back when the close succeeds.
"""
def readChunk(offset, length):
"""
Read from the file.
If EOF is reached before any data is read, raise EOFError.
This method returns the data as a string, or a Deferred that is
called back with same.
@param offset: an integer that is the index to start from in the file.
@param length: the maximum length of data to return. The actual amount
returned may less than this. For normal disk files, however,
this should read the requested number (up to the end of the file).
"""
def writeChunk(offset, data):
"""
Write to the file.
This method returns when the write completes, or a Deferred that is
called when it completes.
@param offset: an integer that is the index to start from in the file.
@param data: a string that is the data to write.
"""
def getAttrs():
"""
Return the attributes for the file.
This method returns a dictionary in the same format as the attrs
argument to L{openFile} or a L{Deferred} that is called back with same.
"""
def setAttrs(attrs):
"""
Set the attributes for the file.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""

View File

@ -0,0 +1,75 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import array
import stat
from time import time, strftime, localtime
# locale-independent month names to use instead of strftime's
_MONTH_NAMES = dict(zip(
range(1, 13),
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()))
def lsLine(name, s):
"""
Build an 'ls' line for a file ('file' in its generic sense, it
can be of any type).
"""
mode = s.st_mode
perms = array.array('c', '-'*10)
ft = stat.S_IFMT(mode)
if stat.S_ISDIR(ft): perms[0] = 'd'
elif stat.S_ISCHR(ft): perms[0] = 'c'
elif stat.S_ISBLK(ft): perms[0] = 'b'
elif stat.S_ISREG(ft): perms[0] = '-'
elif stat.S_ISFIFO(ft): perms[0] = 'f'
elif stat.S_ISLNK(ft): perms[0] = 'l'
elif stat.S_ISSOCK(ft): perms[0] = 's'
else: perms[0] = '!'
# user
if mode&stat.S_IRUSR:perms[1] = 'r'
if mode&stat.S_IWUSR:perms[2] = 'w'
if mode&stat.S_IXUSR:perms[3] = 'x'
# group
if mode&stat.S_IRGRP:perms[4] = 'r'
if mode&stat.S_IWGRP:perms[5] = 'w'
if mode&stat.S_IXGRP:perms[6] = 'x'
# other
if mode&stat.S_IROTH:perms[7] = 'r'
if mode&stat.S_IWOTH:perms[8] = 'w'
if mode&stat.S_IXOTH:perms[9] = 'x'
# suid/sgid
if mode&stat.S_ISUID:
if perms[3] == 'x': perms[3] = 's'
else: perms[3] = 'S'
if mode&stat.S_ISGID:
if perms[6] == 'x': perms[6] = 's'
else: perms[6] = 'S'
lsresult = [
perms.tostring(),
str(s.st_nlink).rjust(5),
' ',
str(s.st_uid).ljust(9),
str(s.st_gid).ljust(9),
str(s.st_size).rjust(8),
' ',
]
# need to specify the month manually, as strftime depends on locale
ttup = localtime(s.st_mtime)
sixmonths = 60 * 60 * 24 * 7 * 26
if s.st_mtime + sixmonths < time(): # last edited more than 6mo ago
strtime = strftime("%%s %d %Y ", ttup)
else:
strtime = strftime("%%s %d %H:%M ", ttup)
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
lsresult.append(name)
return ''.join(lsresult)
__all__ = ['lsLine']

View File

@ -0,0 +1,340 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Line-input oriented interactive interpreter loop.
Provides classes for handling Python source input and arbitrary output
interactively from a Twisted application. Also included is syntax coloring
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
and reasonable handling of Deferreds.
@author: Jp Calderone
"""
import code, sys, StringIO, tokenize
from twisted.conch import recvline
from twisted.internet import defer
from twisted.python.htmlizer import TokenPrinter
class FileWrapper:
"""Minimal write-file-like object.
Writes are translated into addOutput calls on an object passed to
__init__. Newlines are also converted from network to local style.
"""
softspace = 0
state = 'normal'
def __init__(self, o):
self.o = o
def flush(self):
pass
def write(self, data):
self.o.addOutput(data.replace('\r\n', '\n'))
def writelines(self, lines):
self.write(''.join(lines))
class ManholeInterpreter(code.InteractiveInterpreter):
"""Interactive Interpreter with special output and Deferred support.
Aside from the features provided by L{code.InteractiveInterpreter}, this
class captures sys.stdout output and redirects it to the appropriate
location (the Manhole protocol instance). It also treats Deferreds
which reach the top-level specially: each is formatted to the user with
a unique identifier and a new callback and errback added to it, each of
which will format the unique identifier and the result with which the
Deferred fires and then pass it on to the next participant in the
callback chain.
"""
numDeferreds = 0
def __init__(self, handler, locals=None, filename="<console>"):
code.InteractiveInterpreter.__init__(self, locals)
self._pendingDeferreds = {}
self.handler = handler
self.filename = filename
self.resetBuffer()
def resetBuffer(self):
"""Reset the input buffer."""
self.buffer = []
def push(self, line):
"""Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
"""
self.buffer.append(line)
source = "\n".join(self.buffer)
more = self.runsource(source, self.filename)
if not more:
self.resetBuffer()
return more
def runcode(self, *a, **kw):
orighook, sys.displayhook = sys.displayhook, self.displayhook
try:
origout, sys.stdout = sys.stdout, FileWrapper(self.handler)
try:
code.InteractiveInterpreter.runcode(self, *a, **kw)
finally:
sys.stdout = origout
finally:
sys.displayhook = orighook
def displayhook(self, obj):
self.locals['_'] = obj
if isinstance(obj, defer.Deferred):
# XXX Ick, where is my "hasFired()" interface?
if hasattr(obj, "result"):
self.write(repr(obj))
elif id(obj) in self._pendingDeferreds:
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
else:
d = self._pendingDeferreds
k = self.numDeferreds
d[id(obj)] = (k, obj)
self.numDeferreds += 1
obj.addCallbacks(self._cbDisplayDeferred, self._ebDisplayDeferred,
callbackArgs=(k, obj), errbackArgs=(k, obj))
self.write("<Deferred #%d>" % (k,))
elif obj is not None:
self.write(repr(obj))
def _cbDisplayDeferred(self, result, k, obj):
self.write("Deferred #%d called back: %r" % (k, result), True)
del self._pendingDeferreds[id(obj)]
return result
def _ebDisplayDeferred(self, failure, k, obj):
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
del self._pendingDeferreds[id(obj)]
return failure
def write(self, data, async=False):
self.handler.addOutput(data, async)
CTRL_C = '\x03'
CTRL_D = '\x04'
CTRL_BACKSLASH = '\x1c'
CTRL_L = '\x0c'
CTRL_A = '\x01'
CTRL_E = '\x05'
class Manhole(recvline.HistoricRecvLine):
"""Mediator between a fancy line source and an interactive interpreter.
This accepts lines from its transport and passes them on to a
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
with something approximating their normal terminal-mode behavior. It
can optionally be constructed with a dict which will be used as the
local namespace for any code executed.
"""
namespace = None
def __init__(self, namespace=None):
recvline.HistoricRecvLine.__init__(self)
if namespace is not None:
self.namespace = namespace.copy()
def connectionMade(self):
recvline.HistoricRecvLine.connectionMade(self)
self.interpreter = ManholeInterpreter(self, self.namespace)
self.keyHandlers[CTRL_C] = self.handle_INT
self.keyHandlers[CTRL_D] = self.handle_EOF
self.keyHandlers[CTRL_L] = self.handle_FF
self.keyHandlers[CTRL_A] = self.handle_HOME
self.keyHandlers[CTRL_E] = self.handle_END
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
def handle_INT(self):
"""
Handle ^C as an interrupt keystroke by resetting the current input
variables to their initial state.
"""
self.pn = 0
self.lineBuffer = []
self.lineBufferIndex = 0
self.interpreter.resetBuffer()
self.terminal.nextLine()
self.terminal.write("KeyboardInterrupt")
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
def handle_EOF(self):
if self.lineBuffer:
self.terminal.write('\a')
else:
self.handle_QUIT()
def handle_FF(self):
"""
Handle a 'form feed' byte - generally used to request a screen
refresh/redraw.
"""
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.drawInputLine()
def handle_QUIT(self):
self.terminal.loseConnection()
def _needsNewline(self):
w = self.terminal.lastWrite
return not w.endswith('\n') and not w.endswith('\x1bE')
def addOutput(self, bytes, async=False):
if async:
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]))
self.terminal.write(bytes)
if async:
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
if self.lineBuffer:
oldBuffer = self.lineBuffer
self.lineBuffer = []
self.lineBufferIndex = 0
self._deliverBuffer(oldBuffer)
def lineReceived(self, line):
more = self.interpreter.push(line)
self.pn = bool(more)
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
class VT102Writer:
"""Colorizer for Python tokens.
A series of tokens are written to instances of this object. Each is
colored in a particular way. The final line of the result of this is
generally added to the output.
"""
typeToColor = {
'identifier': '\x1b[31m',
'keyword': '\x1b[32m',
'parameter': '\x1b[33m',
'variable': '\x1b[1;33m',
'string': '\x1b[35m',
'number': '\x1b[36m',
'op': '\x1b[37m'}
normalColor = '\x1b[0m'
def __init__(self):
self.written = []
def color(self, type):
r = self.typeToColor.get(type, '')
return r
def write(self, token, type=None):
if token and token != '\r':
c = self.color(type)
if c:
self.written.append(c)
self.written.append(token)
if c:
self.written.append(self.normalColor)
def __str__(self):
s = ''.join(self.written)
return s.strip('\n').splitlines()[-1]
def lastColorizedLine(source):
"""Tokenize and colorize the given Python source.
Returns a VT102-format colorized version of the last line of C{source}.
"""
w = VT102Writer()
p = TokenPrinter(w.write).printtoken
s = StringIO.StringIO(source)
tokenize.tokenize(s.readline, p)
return str(w)
class ColoredManhole(Manhole):
"""A REPL which syntax colors input as users type it.
"""
def getSource(self):
"""Return a string containing the currently entered source.
This is only the code which will be considered for execution
next.
"""
return ('\n'.join(self.interpreter.buffer) +
'\n' +
''.join(self.lineBuffer))
def characterReceived(self, ch, moreCharactersComing):
if self.mode == 'insert':
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
self.lineBufferIndex += 1
if moreCharactersComing:
# Skip it all, we'll get called with another character in
# like 2 femtoseconds.
return
if ch == ' ':
# Don't bother to try to color whitespace
self.terminal.write(ch)
return
source = self.getSource()
# Try to write some junk
try:
coloredLine = lastColorizedLine(source)
except tokenize.TokenError:
# We couldn't do it. Strange. Oh well, just add the character.
self.terminal.write(ch)
else:
# Success! Clear the source on this line.
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]) - 1)
# And write a new, colorized one.
self.terminal.write(self.ps[self.pn] + coloredLine)
# And move the cursor to where it belongs
n = len(self.lineBuffer) - self.lineBufferIndex
if n:
self.terminal.cursorBackward(n)

View File

@ -0,0 +1,146 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
insults/SSH integration support.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.conch import avatar, interfaces as iconch, error as econch
from twisted.conch.ssh import factory, keys, session
from twisted.python import components
from twisted.conch.insults import insults
class _Glue:
"""A feeble class for making one attribute look like another.
This should be replaced with a real class at some point, probably.
Try not to write new code that uses it.
"""
def __init__(self, **kw):
self.__dict__.update(kw)
def __getattr__(self, name):
raise AttributeError(self.name, "has no attribute", name)
class TerminalSessionTransport:
def __init__(self, proto, chainedProtocol, avatar, width, height):
self.proto = proto
self.avatar = avatar
self.chainedProtocol = chainedProtocol
protoSession = self.proto.session
self.proto.makeConnection(
_Glue(write=self.chainedProtocol.dataReceived,
loseConnection=lambda: avatar.conn.sendClose(protoSession),
name="SSH Proto Transport"))
def loseConnection():
self.proto.loseConnection()
self.chainedProtocol.makeConnection(
_Glue(write=self.proto.write,
loseConnection=loseConnection,
name="Chained Proto Transport"))
# XXX TODO
# chainedProtocol is supposed to be an ITerminalTransport,
# maybe. That means perhaps its terminalProtocol attribute is
# an ITerminalProtocol, it could be. So calling terminalSize
# on that should do the right thing But it'd be nice to clean
# this bit up.
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
@implementer(iconch.ISession)
class TerminalSession(components.Adapter):
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def getPty(self, term, windowSize, attrs):
self.height, self.width = windowSize[:2]
def openShell(self, proto):
self.transportFactory(
proto, self.chainedProtocolFactory(),
iconch.IConchUser(self.original),
self.width, self.height)
def execCommand(self, proto, cmd):
raise econch.ConchError("Cannot execute commands")
def closed(self):
pass
class TerminalUser(avatar.ConchUser, components.Adapter):
def __init__(self, original, avatarId):
components.Adapter.__init__(self, original)
avatar.ConchUser.__init__(self)
self.channelLookup['session'] = session.SSHSession
class TerminalRealm:
userFactory = TerminalUser
sessionFactory = TerminalSession
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def _getAvatar(self, avatarId):
comp = components.Componentized()
user = self.userFactory(comp, avatarId)
sess = self.sessionFactory(comp)
sess.transportFactory = self.transportFactory
sess.chainedProtocolFactory = self.chainedProtocolFactory
comp.setComponent(iconch.IConchUser, user)
comp.setComponent(iconch.ISession, sess)
return user
def __init__(self, transportFactory=None):
if transportFactory is not None:
self.transportFactory = transportFactory
def requestAvatar(self, avatarId, mind, *interfaces):
for i in interfaces:
if i is iconch.IConchUser:
return (iconch.IConchUser,
self._getAvatar(avatarId),
lambda: None)
raise NotImplementedError()
class ConchFactory(factory.SSHFactory):
publicKey = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJSkbh/C+BR3utDS555mV'
publicKeys = {
'ssh-rsa' : keys.Key.fromString(publicKey)
}
del publicKey
privateKey = """-----BEGIN RSA PRIVATE KEY-----
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
-----END RSA PRIVATE KEY-----"""
privateKeys = {
'ssh-rsa' : keys.Key.fromString(privateKey)
}
del privateKey
def __init__(self, portal):
self.portal = portal

View File

@ -0,0 +1,123 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
TAP plugin for creating telnet- and ssh-accessible manhole servers.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.internet import protocol
from twisted.application import service, strports
from twisted.cred import portal, checkers
from twisted.python import usage
from twisted.conch.insults import insults
from twisted.conch import manhole, manhole_ssh, telnet
class makeTelnetProtocol:
def __init__(self, portal):
self.portal = portal
def __call__(self):
auth = telnet.AuthenticatingTelnetProtocol
args = (self.portal,)
return telnet.TelnetTransport(auth, *args)
class chainedProtocolFactory:
def __init__(self, namespace):
self.namespace = namespace
def __call__(self):
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
@implementer(portal.IRealm)
class _StupidRealm:
def __init__(self, proto, *a, **kw):
self.protocolFactory = proto
self.protocolArgs = a
self.protocolKwArgs = kw
def requestAvatar(self, avatarId, *interfaces):
if telnet.ITelnetProtocol in interfaces:
return (telnet.ITelnetProtocol,
self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs),
lambda: None)
raise NotImplementedError()
class Options(usage.Options):
optParameters = [
["telnetPort", "t", None, "strports description of the address on which to listen for telnet connections"],
["sshPort", "s", None, "strports description of the address on which to listen for ssh connections"],
["passwd", "p", "/etc/passwd", "name of a passwd(5)-format username/password file"]]
def __init__(self):
usage.Options.__init__(self)
self['namespace'] = None
def postOptions(self):
if self['telnetPort'] is None and self['sshPort'] is None:
raise usage.UsageError("At least one of --telnetPort and --sshPort must be specified")
def makeService(options):
"""Create a manhole server service.
@type options: C{dict}
@param options: A mapping describing the configuration of
the desired service. Recognized key/value pairs are::
"telnetPort": strports description of the address on which
to listen for telnet connections. If None,
no telnet service will be started.
"sshPort": strports description of the address on which to
listen for ssh connections. If None, no ssh
service will be started.
"namespace": dictionary containing desired initial locals
for manhole connections. If None, an empty
dictionary will be used.
"passwd": Name of a passwd(5)-format username/password file.
@rtype: L{twisted.application.service.IService}
@return: A manhole service.
"""
svc = service.MultiService()
namespace = options['namespace']
if namespace is None:
namespace = {}
checker = checkers.FilePasswordDB(options['passwd'])
if options['telnetPort']:
telnetRealm = _StupidRealm(telnet.TelnetBootstrapProtocol,
insults.ServerProtocol,
manhole.ColoredManhole,
namespace)
telnetPortal = portal.Portal(telnetRealm, [checker])
telnetFactory = protocol.ServerFactory()
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
telnetService = strports.service(options['telnetPort'],
telnetFactory)
telnetService.setServiceParent(svc)
if options['sshPort']:
sshRealm = manhole_ssh.TerminalRealm()
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
sshPortal = portal.Portal(sshRealm, [checker])
sshFactory = manhole_ssh.ConchFactory(sshPortal)
sshService = strports.service(options['sshPort'],
sshFactory)
sshService.setServiceParent(svc)
return svc

View File

@ -0,0 +1,49 @@
# -*- test-case-name: twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Experimental optimization
This module provides a single mixin class which allows protocols to
collapse numerous small writes into a single larger one.
@author: Jp Calderone
"""
from twisted.internet import reactor
class BufferingMixin:
"""Mixin which adds write buffering.
"""
_delayedWriteCall = None
bytes = None
DELAY = 0.0
def schedule(self):
return reactor.callLater(self.DELAY, self.flush)
def reschedule(self, token):
token.reset(self.DELAY)
def write(self, bytes):
"""Buffer some bytes to be written soon.
Every call to this function delays the real write by C{self.DELAY}
seconds. When the delay expires, all collected bytes are written
to the underlying transport using L{ITransport.writeSequence}.
"""
if self._delayedWriteCall is None:
self.bytes = []
self._delayedWriteCall = self.schedule()
else:
self.reschedule(self._delayedWriteCall)
self.bytes.append(bytes)
def flush(self):
"""Flush the buffer immediately.
"""
self._delayedWriteCall = None
self.transport.writeSequence(self.bytes)
self.bytes = None

View File

@ -0,0 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Support for OpenSSH configuration files.
Maintainer: Paul Swartz
"""

View File

@ -0,0 +1,73 @@
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Factory for reading openssh configuration files: public keys, private keys, and
moduli file.
"""
import os, errno
from twisted.python import log
from twisted.python.util import runAsEffectiveUser
from twisted.conch.ssh import keys, factory, common
from twisted.conch.openssh_compat import primes
class OpenSSHFactory(factory.SSHFactory):
dataRoot = '/usr/local/etc'
moduliRoot = '/usr/local/etc' # for openbsd which puts moduli in a different
# directory from keys
def getPublicKeys(self):
"""
Return the server public keys.
"""
ks = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == 'ssh_host_' and filename[-8:]=='_key.pub':
try:
k = keys.Key.fromFile(
os.path.join(self.dataRoot, filename))
t = common.getNS(k.blob())[0]
ks[t] = k
except Exception, e:
log.msg('bad public key file %s: %s' % (filename, e))
return ks
def getPrivateKeys(self):
"""
Return the server private keys.
"""
privateKeys = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == 'ssh_host_' and filename[-4:]=='_key':
fullPath = os.path.join(self.dataRoot, filename)
try:
key = keys.Key.fromFile(fullPath)
except IOError, e:
if e.errno == errno.EACCES:
# Not allowed, let's switch to root
key = runAsEffectiveUser(0, 0, keys.Key.fromFile, fullPath)
keyType = keys.objectType(key.keyObject)
privateKeys[keyType] = key
else:
raise
except Exception, e:
log.msg('bad private key file %s: %s' % (filename, e))
else:
keyType = keys.objectType(key.keyObject)
privateKeys[keyType] = key
return privateKeys
def getPrimes(self):
try:
return primes.parseModuliFile(self.moduliRoot+'/moduli')
except IOError:
return None

View File

@ -0,0 +1,26 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
Maintainer: Paul Swartz
"""
def parseModuliFile(filename):
lines = open(filename).readlines()
primes = {}
for l in lines:
l = l.strip()
if not l or l[0]=='#':
continue
tim, typ, tst, tri, size, gen, mod = l.split()
size = int(size) + 1
gen = long(gen)
mod = long(mod, 16)
if not primes.has_key(size):
primes[size] = []
primes[size].append((gen, mod))
return primes

View File

@ -0,0 +1,331 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Basic line editing support.
@author: Jp Calderone
"""
import string
from zope.interface import implementer
from twisted.conch.insults import insults, helper
from twisted.python import log, reflect
_counters = {}
class Logging(object):
"""Wrapper which logs attribute lookups.
This was useful in debugging something, I guess. I forget what.
It can probably be deleted or moved somewhere more appropriate.
Nothing special going on here, really.
"""
def __init__(self, original):
self.original = original
key = reflect.qual(original.__class__)
count = _counters.get(key, 0)
_counters[key] = count + 1
self._logFile = file(key + '-' + str(count), 'w')
def __str__(self):
return str(super(Logging, self).__getattribute__('original'))
def __repr__(self):
return repr(super(Logging, self).__getattribute__('original'))
def __getattribute__(self, name):
original = super(Logging, self).__getattribute__('original')
logFile = super(Logging, self).__getattribute__('_logFile')
logFile.write(name + '\n')
return getattr(original, name)
@implementer(insults.ITerminalTransport)
class TransportSequence(object):
"""An L{ITerminalTransport} implementation which forwards calls to
one or more other L{ITerminalTransport}s.
This is a cheap way for servers to keep track of the state they
expect the client to see, since all terminal manipulations can be
send to the real client and to a terminal emulator that lives in
the server process.
"""
for keyID in ('UP_ARROW', 'DOWN_ARROW', 'RIGHT_ARROW', 'LEFT_ARROW',
'HOME', 'INSERT', 'DELETE', 'END', 'PGUP', 'PGDN',
'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9',
'F10', 'F11', 'F12'):
exec '%s = object()' % (keyID,)
TAB = '\t'
BACKSPACE = '\x7f'
def __init__(self, *transports):
assert transports, "Cannot construct a TransportSequence with no transports"
self.transports = transports
for method in insults.ITerminalTransport:
exec """\
def %s(self, *a, **kw):
for tpt in self.transports:
result = tpt.%s(*a, **kw)
return result
""" % (method, method)
class LocalTerminalBufferMixin(object):
"""A mixin for RecvLine subclasses which records the state of the terminal.
This is accomplished by performing all L{ITerminalTransport} operations on both
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
will be made to keep up to date with the actual terminal
associated with this protocol instance.
"""
def makeConnection(self, transport):
self.terminalCopy = helper.TerminalBuffer()
self.terminalCopy.connectionMade()
return super(LocalTerminalBufferMixin, self).makeConnection(
TransportSequence(transport, self.terminalCopy))
def __str__(self):
return str(self.terminalCopy)
class RecvLine(insults.TerminalProtocol):
"""L{TerminalProtocol} which adds line editing features.
Clients will be prompted for lines of input with all the usual
features: character echoing, left and right arrow support for
moving the cursor to different areas of the line buffer, backspace
and delete for removing characters, and insert for toggling
between typeover and insert mode. Tabs will be expanded to enough
spaces to move the cursor to the next tabstop (every four
characters by default). Enter causes the line buffer to be
cleared and the line to be passed to the lineReceived() method
which, by default, does nothing. Subclasses are responsible for
redrawing the input prompt (this will probably change).
"""
width = 80
height = 24
TABSTOP = 4
ps = ('>>> ', '... ')
pn = 0
_printableChars = set(string.printable)
def connectionMade(self):
# A list containing the characters making up the current line
self.lineBuffer = []
# A zero-based (wtf else?) index into self.lineBuffer.
# Indicates the current cursor position.
self.lineBufferIndex = 0
t = self.terminal
# A map of keyIDs to bound instance methods.
self.keyHandlers = {
t.LEFT_ARROW: self.handle_LEFT,
t.RIGHT_ARROW: self.handle_RIGHT,
t.TAB: self.handle_TAB,
# Both of these should not be necessary, but figuring out
# which is necessary is a huge hassle.
'\r': self.handle_RETURN,
'\n': self.handle_RETURN,
t.BACKSPACE: self.handle_BACKSPACE,
t.DELETE: self.handle_DELETE,
t.INSERT: self.handle_INSERT,
t.HOME: self.handle_HOME,
t.END: self.handle_END}
self.initializeScreen()
def initializeScreen(self):
# Hmm, state sucks. Oh well.
# For now we will just take over the whole terminal.
self.terminal.reset()
self.terminal.write(self.ps[self.pn])
# XXX Note: I would prefer to default to starting in insert
# mode, however this does not seem to actually work! I do not
# know why. This is probably of interest to implementors
# subclassing RecvLine.
# XXX XXX Note: But the unit tests all expect the initial mode
# to be insert right now. Fuck, there needs to be a way to
# query the current mode or something.
# self.setTypeoverMode()
self.setInsertMode()
def currentLineBuffer(self):
s = ''.join(self.lineBuffer)
return s[:self.lineBufferIndex], s[self.lineBufferIndex:]
def setInsertMode(self):
self.mode = 'insert'
self.terminal.setModes([insults.modes.IRM])
def setTypeoverMode(self):
self.mode = 'typeover'
self.terminal.resetModes([insults.modes.IRM])
def drawInputLine(self):
"""
Write a line containing the current input prompt and the current line
buffer at the current cursor position.
"""
self.terminal.write(self.ps[self.pn] + ''.join(self.lineBuffer))
def terminalSize(self, width, height):
# XXX - Clear the previous input line, redraw it at the new
# cursor position
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.width = width
self.height = height
self.drawInputLine()
def unhandledControlSequence(self, seq):
pass
def keystrokeReceived(self, keyID, modifier):
m = self.keyHandlers.get(keyID)
if m is not None:
m()
elif keyID in self._printableChars:
self.characterReceived(keyID, False)
else:
log.msg("Received unhandled keyID: %r" % (keyID,))
def characterReceived(self, ch, moreCharactersComing):
if self.mode == 'insert':
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
self.lineBufferIndex += 1
self.terminal.write(ch)
def handle_TAB(self):
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
self.terminal.cursorForward(n)
self.lineBufferIndex += n
self.lineBuffer.extend(' ' * n)
def handle_LEFT(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
self.terminal.cursorBackward()
def handle_RIGHT(self):
if self.lineBufferIndex < len(self.lineBuffer):
self.lineBufferIndex += 1
self.terminal.cursorForward()
def handle_HOME(self):
if self.lineBufferIndex:
self.terminal.cursorBackward(self.lineBufferIndex)
self.lineBufferIndex = 0
def handle_END(self):
offset = len(self.lineBuffer) - self.lineBufferIndex
if offset:
self.terminal.cursorForward(offset)
self.lineBufferIndex = len(self.lineBuffer)
def handle_BACKSPACE(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
del self.lineBuffer[self.lineBufferIndex]
self.terminal.cursorBackward()
self.terminal.deleteCharacter()
def handle_DELETE(self):
if self.lineBufferIndex < len(self.lineBuffer):
del self.lineBuffer[self.lineBufferIndex]
self.terminal.deleteCharacter()
def handle_RETURN(self):
line = ''.join(self.lineBuffer)
self.lineBuffer = []
self.lineBufferIndex = 0
self.terminal.nextLine()
self.lineReceived(line)
def handle_INSERT(self):
assert self.mode in ('typeover', 'insert')
if self.mode == 'typeover':
self.setInsertMode()
else:
self.setTypeoverMode()
def lineReceived(self, line):
pass
class HistoricRecvLine(RecvLine):
"""L{TerminalProtocol} which adds both basic line-editing features and input history.
Everything supported by L{RecvLine} is also supported by this class. In addition, the
up and down arrows traverse the input history. Each received line is automatically
added to the end of the input history.
"""
def connectionMade(self):
RecvLine.connectionMade(self)
self.historyLines = []
self.historyPosition = 0
t = self.terminal
self.keyHandlers.update({t.UP_ARROW: self.handle_UP,
t.DOWN_ARROW: self.handle_DOWN})
def currentHistoryBuffer(self):
b = tuple(self.historyLines)
return b[:self.historyPosition], b[self.historyPosition:]
def _deliverBuffer(self, buf):
if buf:
for ch in buf[:-1]:
self.characterReceived(ch, True)
self.characterReceived(buf[-1], False)
def handle_UP(self):
if self.lineBuffer and self.historyPosition == len(self.historyLines):
self.historyLines.append(self.lineBuffer)
if self.historyPosition > 0:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition -= 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
def handle_DOWN(self):
if self.historyPosition < len(self.historyLines) - 1:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition += 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
else:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition = len(self.historyLines)
self.lineBuffer = []
self.lineBufferIndex = 0
def handle_RETURN(self):
if self.lineBuffer:
self.historyLines.append(''.join(self.lineBuffer))
self.historyPosition = len(self.historyLines)
return RecvLine.handle_RETURN(self)

View File

@ -0,0 +1 @@
'conch scripts'

View File

@ -0,0 +1,834 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the I{cftp} command.
"""
import os, sys, getpass, struct, tty, fcntl, stat
import fnmatch, pwd, glob
from twisted.conch.client import connect, default, options
from twisted.conch.ssh import connection, common
from twisted.conch.ssh import channel, filetransfer
from twisted.protocols import basic
from twisted.internet import reactor, stdio, defer, utils
from twisted.python import log, usage, failure
class ClientOptions(options.ConchOptions):
synopsis = """Usage: cftp [options] [user@]host
cftp [options] [user@]host[:dir[/]]
cftp [options] [user@]host[:file [localfile]]
"""
longdesc = ("cftp is a client for logging into a remote machine and "
"executing commands to send and receive file information")
optParameters = [
['buffersize', 'B', 32768, 'Size of the buffer to use for sending/receiving.'],
['batchfile', 'b', None, 'File to read commands from, or \'-\' for stdin.'],
['requests', 'R', 5, 'Number of requests to make before waiting for a reply.'],
['subsystem', 's', 'sftp', 'Subsystem/server program to connect to.']]
compData = usage.Completions(
descriptions={
"buffersize": "Size of send/receive buffer (default: 32768)"},
extraActions=[usage.CompleteUserAtHost(),
usage.CompleteFiles(descr="local file")])
def parseArgs(self, host, localPath=None):
self['remotePath'] = ''
if ':' in host:
host, self['remotePath'] = host.split(':', 1)
self['remotePath'].rstrip('/')
self['host'] = host
self['localPath'] = localPath
def run():
# import hotshot
# prof = hotshot.Profile('cftp.prof')
# prof.start()
args = sys.argv[1:]
if '-l' in args: # cvs is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError, u:
print 'ERROR: %s' % u
sys.exit(1)
if options['log']:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
doConnect(options)
reactor.run()
# prof.stop()
# prof.close()
def handleError():
global exitStatus
exitStatus = 2
try:
reactor.stop()
except: pass
log.err(failure.Failure())
raise
def doConnect(options):
# log.deferr = handleError # HACK
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@',1)
host = options['host']
if not options['user']:
options['user'] = getpass.getuser()
if not options['port']:
options['port'] = 22
else:
options['port'] = int(options['port'])
host = options['host']
port = options['port']
conn = SSHConnection()
conn.options = options
vhk = default.verifyHostKey
uao = default.SSHUserAuthClient(options['user'], options, conn)
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
#global exitStatus
if hasattr(f.value, 'value'):
s = f.value.value
else:
s = str(f)
print s
#exitStatus = "conch: exiting with error %s" % f
try:
reactor.stop()
except: pass
def _ignore(*args): pass
class FileWrapper:
def __init__(self, f):
self.f = f
self.total = 0.0
f.seek(0, 2) # seek to the end
self.size = f.tell()
def __getattr__(self, attr):
return getattr(self.f, attr)
class StdioClient(basic.LineReceiver):
_pwd = pwd
ps = 'cftp> '
delimiter = '\n'
reactor = reactor
def __init__(self, client, f = None):
self.client = client
self.currentDirectory = ''
self.file = f
self.useProgressBar = (not f and 1) or 0
def connectionMade(self):
self.client.realPath('').addCallback(self._cbSetCurDir)
def _cbSetCurDir(self, path):
self.currentDirectory = path
self._newLine()
def lineReceived(self, line):
if self.client.transport.localClosed:
return
log.msg('got line %s' % repr(line))
line = line.lstrip()
if not line:
self._newLine()
return
if self.file and line.startswith('-'):
self.ignoreErrors = 1
line = line[1:]
else:
self.ignoreErrors = 0
d = self._dispatchCommand(line)
if d is not None:
d.addCallback(self._cbCommand)
d.addErrback(self._ebCommand)
def _dispatchCommand(self, line):
if ' ' in line:
command, rest = line.split(' ', 1)
rest = rest.lstrip()
else:
command, rest = line, ''
if command.startswith('!'): # command
f = self.cmd_EXEC
rest = (command[1:] + ' ' + rest).strip()
else:
command = command.upper()
log.msg('looking up cmd %s' % command)
f = getattr(self, 'cmd_%s' % command, None)
if f is not None:
return defer.maybeDeferred(f, rest)
else:
self._ebCommand(failure.Failure(NotImplementedError(
"No command called `%s'" % command)))
self._newLine()
def _printFailure(self, f):
log.msg(f)
e = f.trap(NotImplementedError, filetransfer.SFTPError, OSError, IOError)
if e == NotImplementedError:
self.transport.write(self.cmd_HELP(''))
elif e == filetransfer.SFTPError:
self.transport.write("remote error %i: %s\n" %
(f.value.code, f.value.message))
elif e in (OSError, IOError):
self.transport.write("local error %i: %s\n" %
(f.value.errno, f.value.strerror))
def _newLine(self):
if self.client.transport.localClosed:
return
self.transport.write(self.ps)
self.ignoreErrors = 0
if self.file:
l = self.file.readline()
if not l:
self.client.transport.loseConnection()
else:
self.transport.write(l)
self.lineReceived(l.strip())
def _cbCommand(self, result):
if result is not None:
self.transport.write(result)
if not result.endswith('\n'):
self.transport.write('\n')
self._newLine()
def _ebCommand(self, f):
self._printFailure(f)
if self.file and not self.ignoreErrors:
self.client.transport.loseConnection()
self._newLine()
def cmd_CD(self, path):
path, rest = self._getFilename(path)
if not path.endswith('/'):
path += '/'
newPath = path and os.path.join(self.currentDirectory, path) or ''
d = self.client.openDirectory(newPath)
d.addCallback(self._cbCd)
d.addErrback(self._ebCommand)
return d
def _cbCd(self, directory):
directory.close()
d = self.client.realPath(directory.name)
d.addCallback(self._cbCurDir)
return d
def _cbCurDir(self, path):
self.currentDirectory = path
def cmd_CHGRP(self, rest):
grp, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
grp = int(grp)
d = self.client.getAttrs(path)
d.addCallback(self._cbSetUsrGrp, path, grp=grp)
return d
def cmd_CHMOD(self, rest):
mod, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
mod = int(mod, 8)
d = self.client.setAttrs(path, {'permissions':mod})
d.addCallback(_ignore)
return d
def cmd_CHOWN(self, rest):
usr, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
usr = int(usr)
d = self.client.getAttrs(path)
d.addCallback(self._cbSetUsrGrp, path, usr=usr)
return d
def _cbSetUsrGrp(self, attrs, path, usr=None, grp=None):
new = {}
new['uid'] = (usr is not None) and usr or attrs['uid']
new['gid'] = (grp is not None) and grp or attrs['gid']
d = self.client.setAttrs(path, new)
d.addCallback(_ignore)
return d
def cmd_GET(self, rest):
remote, rest = self._getFilename(rest)
if '*' in remote or '?' in remote: # wildcard
if rest:
local, rest = self._getFilename(rest)
if not os.path.isdir(local):
return "Wildcard get with non-directory target."
else:
local = ''
d = self._remoteGlob(remote)
d.addCallback(self._cbGetMultiple, local)
return d
if rest:
local, rest = self._getFilename(rest)
else:
local = os.path.split(remote)[1]
log.msg((remote, local))
lf = file(local, 'w', 0)
path = os.path.join(self.currentDirectory, remote)
d = self.client.openFile(path, filetransfer.FXF_READ, {})
d.addCallback(self._cbGetOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
return d
def _cbGetMultiple(self, files, local):
#if self._useProgressBar: # one at a time
# XXX this can be optimized for times w/o progress bar
return self._cbGetMultipleNext(None, files, local)
def _cbGetMultipleNext(self, res, files, local):
if isinstance(res, failure.Failure):
self._printFailure(res)
elif res:
self.transport.write(res)
if not res.endswith('\n'):
self.transport.write('\n')
if not files:
return
f = files.pop(0)[0]
lf = file(os.path.join(local, os.path.split(f)[1]), 'w', 0)
path = os.path.join(self.currentDirectory, f)
d = self.client.openFile(path, filetransfer.FXF_READ, {})
d.addCallback(self._cbGetOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
d.addBoth(self._cbGetMultipleNext, files, local)
return d
def _ebCloseLf(self, f, lf):
lf.close()
return f
def _cbGetOpenFile(self, rf, lf):
return rf.getAttrs().addCallback(self._cbGetFileSize, rf, lf)
def _cbGetFileSize(self, attrs, rf, lf):
if not stat.S_ISREG(attrs['permissions']):
rf.close()
lf.close()
return "Can't get non-regular file: %s" % rf.name
rf.size = attrs['size']
bufferSize = self.client.transport.conn.options['buffersize']
numRequests = self.client.transport.conn.options['requests']
rf.total = 0.0
dList = []
chunks = []
startTime = self.reactor.seconds()
for i in range(numRequests):
d = self._cbGetRead('', rf, lf, chunks, 0, bufferSize, startTime)
dList.append(d)
dl = defer.DeferredList(dList, fireOnOneErrback=1)
dl.addCallback(self._cbGetDone, rf, lf)
return dl
def _getNextChunk(self, chunks):
end = 0
for chunk in chunks:
if end == 'eof':
return # nothing more to get
if end != chunk[0]:
i = chunks.index(chunk)
chunks.insert(i, (end, chunk[0]))
return (end, chunk[0] - end)
end = chunk[1]
bufSize = int(self.client.transport.conn.options['buffersize'])
chunks.append((end, end + bufSize))
return (end, bufSize)
def _cbGetRead(self, data, rf, lf, chunks, start, size, startTime):
if data and isinstance(data, failure.Failure):
log.msg('get read err: %s' % data)
reason = data
reason.trap(EOFError)
i = chunks.index((start, start + size))
del chunks[i]
chunks.insert(i, (start, 'eof'))
elif data:
log.msg('get read data: %i' % len(data))
lf.seek(start)
lf.write(data)
if len(data) != size:
log.msg('got less than we asked for: %i < %i' %
(len(data), size))
i = chunks.index((start, start + size))
del chunks[i]
chunks.insert(i, (start, start + len(data)))
rf.total += len(data)
if self.useProgressBar:
self._printProgressBar(rf, startTime)
chunk = self._getNextChunk(chunks)
if not chunk:
return
else:
start, length = chunk
log.msg('asking for %i -> %i' % (start, start+length))
d = rf.readChunk(start, length)
d.addBoth(self._cbGetRead, rf, lf, chunks, start, length, startTime)
return d
def _cbGetDone(self, ignored, rf, lf):
log.msg('get done')
rf.close()
lf.close()
if self.useProgressBar:
self.transport.write('\n')
return "Transferred %s to %s" % (rf.name, lf.name)
def cmd_PUT(self, rest):
local, rest = self._getFilename(rest)
if '*' in local or '?' in local: # wildcard
if rest:
remote, rest = self._getFilename(rest)
path = os.path.join(self.currentDirectory, remote)
d = self.client.getAttrs(path)
d.addCallback(self._cbPutTargetAttrs, remote, local)
return d
else:
remote = ''
files = glob.glob(local)
return self._cbPutMultipleNext(None, files, remote)
if rest:
remote, rest = self._getFilename(rest)
else:
remote = os.path.split(local)[1]
lf = file(local, 'r')
path = os.path.join(self.currentDirectory, remote)
flags = filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|filetransfer.FXF_TRUNC
d = self.client.openFile(path, flags, {})
d.addCallback(self._cbPutOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
return d
def _cbPutTargetAttrs(self, attrs, path, local):
if not stat.S_ISDIR(attrs['permissions']):
return "Wildcard put with non-directory target."
# FIXME:7037:
# Check what `files` variable should do here.
return self._cbPutMultipleNext(None, files, path)
def _cbPutMultipleNext(self, res, files, path):
if isinstance(res, failure.Failure):
self._printFailure(res)
elif res:
self.transport.write(res)
if not res.endswith('\n'):
self.transport.write('\n')
f = None
while files and not f:
try:
f = files.pop(0)
lf = file(f, 'r')
except:
self._printFailure(failure.Failure())
f = None
if not f:
return
name = os.path.split(f)[1]
remote = os.path.join(self.currentDirectory, path, name)
log.msg((name, remote, path))
flags = filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|filetransfer.FXF_TRUNC
d = self.client.openFile(remote, flags, {})
d.addCallback(self._cbPutOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
d.addBoth(self._cbPutMultipleNext, files, path)
return d
def _cbPutOpenFile(self, rf, lf):
numRequests = self.client.transport.conn.options['requests']
if self.useProgressBar:
lf = FileWrapper(lf)
dList = []
chunks = []
startTime = self.reactor.seconds()
for i in range(numRequests):
d = self._cbPutWrite(None, rf, lf, chunks, startTime)
if d:
dList.append(d)
dl = defer.DeferredList(dList, fireOnOneErrback=1)
dl.addCallback(self._cbPutDone, rf, lf)
return dl
def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
chunk = self._getNextChunk(chunks)
start, size = chunk
lf.seek(start)
data = lf.read(size)
if self.useProgressBar:
lf.total += len(data)
self._printProgressBar(lf, startTime)
if data:
d = rf.writeChunk(start, data)
d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
return d
else:
return
def _cbPutDone(self, ignored, rf, lf):
lf.close()
rf.close()
if self.useProgressBar:
self.transport.write('\n')
return 'Transferred %s to %s' % (lf.name, rf.name)
def cmd_LCD(self, path):
os.chdir(path)
def cmd_LN(self, rest):
linkpath, rest = self._getFilename(rest)
targetpath, rest = self._getFilename(rest)
linkpath, targetpath = map(
lambda x: os.path.join(self.currentDirectory, x),
(linkpath, targetpath))
return self.client.makeLink(linkpath, targetpath).addCallback(_ignore)
def cmd_LS(self, rest):
# possible lines:
# ls current directory
# ls name_of_file that file
# ls name_of_directory that directory
# ls some_glob_string current directory, globbed for that string
options = []
rest = rest.split()
while rest and rest[0] and rest[0][0] == '-':
opts = rest.pop(0)[1:]
for o in opts:
if o == 'l':
options.append('verbose')
elif o == 'a':
options.append('all')
rest = ' '.join(rest)
path, rest = self._getFilename(rest)
if not path:
fullPath = self.currentDirectory + '/'
else:
fullPath = os.path.join(self.currentDirectory, path)
d = self._remoteGlob(fullPath)
d.addCallback(self._cbDisplayFiles, options)
return d
def _cbDisplayFiles(self, files, options):
files.sort()
if 'all' not in options:
files = [f for f in files if not f[0].startswith('.')]
if 'verbose' in options:
lines = [f[1] for f in files]
else:
lines = [f[0] for f in files]
if not lines:
return None
else:
return '\n'.join(lines)
def cmd_MKDIR(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.makeDirectory(path, {}).addCallback(_ignore)
def cmd_RMDIR(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.removeDirectory(path).addCallback(_ignore)
def cmd_LMKDIR(self, path):
os.system("mkdir %s" % path)
def cmd_RM(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.removeFile(path).addCallback(_ignore)
def cmd_LLS(self, rest):
os.system("ls %s" % rest)
def cmd_RENAME(self, rest):
oldpath, rest = self._getFilename(rest)
newpath, rest = self._getFilename(rest)
oldpath, newpath = map (
lambda x: os.path.join(self.currentDirectory, x),
(oldpath, newpath))
return self.client.renameFile(oldpath, newpath).addCallback(_ignore)
def cmd_EXIT(self, ignored):
self.client.transport.loseConnection()
cmd_QUIT = cmd_EXIT
def cmd_VERSION(self, ignored):
return "SFTP version %i" % self.client.version
def cmd_HELP(self, ignored):
return """Available commands:
cd path Change remote directory to 'path'.
chgrp gid path Change gid of 'path' to 'gid'.
chmod mode path Change mode of 'path' to 'mode'.
chown uid path Change uid of 'path' to 'uid'.
exit Disconnect from the server.
get remote-path [local-path] Get remote file.
help Get a list of available commands.
lcd path Change local directory to 'path'.
lls [ls-options] [path] Display local directory listing.
lmkdir path Create local directory.
ln linkpath targetpath Symlink remote file.
lpwd Print the local working directory.
ls [-l] [path] Display remote directory listing.
mkdir path Create remote directory.
progress Toggle progress bar.
put local-path [remote-path] Put local file.
pwd Print the remote working directory.
quit Disconnect from the server.
rename oldpath newpath Rename remote file.
rmdir path Remove remote directory.
rm path Remove remote file.
version Print the SFTP version.
? Synonym for 'help'.
"""
def cmd_PWD(self, ignored):
return self.currentDirectory
def cmd_LPWD(self, ignored):
return os.getcwd()
def cmd_PROGRESS(self, ignored):
self.useProgressBar = not self.useProgressBar
return "%ssing progess bar." % (self.useProgressBar and "U" or "Not u")
def cmd_EXEC(self, rest):
"""
Run C{rest} using the user's shell (or /bin/sh if they do not have
one).
"""
shell = self._pwd.getpwnam(getpass.getuser())[6]
if not shell:
shell = '/bin/sh'
if rest:
cmds = ['-c', rest]
return utils.getProcessOutput(shell, cmds, errortoo=1)
else:
os.system(shell)
# accessory functions
def _remoteGlob(self, fullPath):
log.msg('looking up %s' % fullPath)
head, tail = os.path.split(fullPath)
if '*' in tail or '?' in tail:
glob = 1
else:
glob = 0
if tail and not glob: # could be file or directory
# try directory first
d = self.client.openDirectory(fullPath)
d.addCallback(self._cbOpenList, '')
d.addErrback(self._ebNotADirectory, head, tail)
else:
d = self.client.openDirectory(head)
d.addCallback(self._cbOpenList, tail)
return d
def _cbOpenList(self, directory, glob):
files = []
d = directory.read()
d.addBoth(self._cbReadFile, files, directory, glob)
return d
def _ebNotADirectory(self, reason, path, glob):
d = self.client.openDirectory(path)
d.addCallback(self._cbOpenList, glob)
return d
def _cbReadFile(self, files, l, directory, glob):
if not isinstance(files, failure.Failure):
if glob:
l.extend([f for f in files if fnmatch.fnmatch(f[0], glob)])
else:
l.extend(files)
d = directory.read()
d.addBoth(self._cbReadFile, l, directory, glob)
return d
else:
reason = files
reason.trap(EOFError)
directory.close()
return l
def _abbrevSize(self, size):
# from http://mail.python.org/pipermail/python-list/1999-December/018395.html
_abbrevs = [
(1<<50L, 'PB'),
(1<<40L, 'TB'),
(1<<30L, 'GB'),
(1<<20L, 'MB'),
(1<<10L, 'kB'),
(1, 'B')
]
for factor, suffix in _abbrevs:
if size > factor:
break
return '%.1f' % (size/factor) + suffix
def _abbrevTime(self, t):
if t > 3600: # 1 hour
hours = int(t / 3600)
t -= (3600 * hours)
mins = int(t / 60)
t -= (60 * mins)
return "%i:%02i:%02i" % (hours, mins, t)
else:
mins = int(t/60)
t -= (60 * mins)
return "%02i:%02i" % (mins, t)
def _printProgressBar(self, f, startTime):
"""
Update a console progress bar on this L{StdioClient}'s transport, based
on the difference between the start time of the operation and the
current time according to the reactor, and appropriate to the size of
the console window.
@param f: a wrapper around the file which is being written or read
@type f: L{FileWrapper}
@param startTime: The time at which the operation being tracked began.
@type startTime: C{float}
"""
diff = self.reactor.seconds() - startTime
total = f.total
try:
winSize = struct.unpack('4H',
fcntl.ioctl(0, tty.TIOCGWINSZ, '12345679'))
except IOError:
winSize = [None, 80]
if diff == 0.0:
speed = 0.0
else:
speed = total / diff
if speed:
timeLeft = (f.size - total) / speed
else:
timeLeft = 0
front = f.name
back = '%3i%% %s %sps %s ' % ((total / f.size) * 100,
self._abbrevSize(total),
self._abbrevSize(speed),
self._abbrevTime(timeLeft))
spaces = (winSize[1] - (len(front) + len(back) + 1)) * ' '
self.transport.write('\r%s%s%s' % (front, spaces, back))
def _getFilename(self, line):
line.lstrip()
if not line:
return None, ''
if line[0] in '\'"':
ret = []
line = list(line)
try:
for i in range(1,len(line)):
c = line[i]
if c == line[0]:
return ''.join(ret), ''.join(line[i+1:]).lstrip()
elif c == '\\': # quoted character
del line[i]
if line[i] not in '\'"\\':
raise IndexError, "bad quote: \\%s" % line[i]
ret.append(line[i])
else:
ret.append(line[i])
except IndexError:
raise IndexError, "unterminated quote"
ret = line.split(None, 1)
if len(ret) == 1:
return ret[0], ''
else:
return ret
StdioClient.__dict__['cmd_?'] = StdioClient.cmd_HELP
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
self.openChannel(SSHSession())
class SSHSession(channel.SSHChannel):
name = 'session'
def channelOpen(self, foo):
log.msg('session %s open' % self.id)
if self.conn.options['subsystem'].startswith('/'):
request = 'exec'
else:
request = 'subsystem'
d = self.conn.sendRequest(self, request, \
common.NS(self.conn.options['subsystem']), wantReply=1)
d.addCallback(self._cbSubsystem)
d.addErrback(_ebExit)
def _cbSubsystem(self, result):
self.client = filetransfer.FileTransferClient()
self.client.makeConnection(self)
self.dataReceived = self.client.dataReceived
f = None
if self.conn.options['batchfile']:
fn = self.conn.options['batchfile']
if fn != '-':
f = file(fn)
self.stdio = stdio.StandardIO(StdioClient(self.client, f))
def extReceived(self, t, data):
if t==connection.EXTENDED_DATA_STDERR:
log.msg('got %s stderr data' % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg('got eof')
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg('remote side closed %s' % self)
self.conn.sendClose(self)
def closed(self):
try:
reactor.stop()
except:
pass
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
if __name__ == '__main__':
run()

View File

@ -0,0 +1,223 @@
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `ckeygen` command.
"""
import sys, os, getpass, socket
if getpass.getpass == getpass.unix_getpass:
try:
import termios # hack around broken termios
termios.tcgetattr, termios.tcsetattr
except (ImportError, AttributeError):
sys.modules['termios'] = None
reload(getpass)
from twisted.conch.ssh import keys
from twisted.python import failure, filepath, log, usage, randbytes
class GeneralOptions(usage.Options):
synopsis = """Usage: ckeygen [options]
"""
longdesc = "ckeygen manipulates public/private keys in various ways."
optParameters = [['bits', 'b', 1024, 'Number of bits in the key to create.'],
['filename', 'f', None, 'Filename of the key file.'],
['type', 't', None, 'Specify type of key to create.'],
['comment', 'C', None, 'Provide new comment.'],
['newpass', 'N', None, 'Provide new passphrase.'],
['pass', 'P', None, 'Provide old passphrase.']]
optFlags = [['fingerprint', 'l', 'Show fingerprint of key file.'],
['changepass', 'p', 'Change passphrase of private key file.'],
['quiet', 'q', 'Quiet.'],
['no-passphrase', None, "Create the key with no passphrase."],
['showpub', 'y', 'Read private key file and print public key.']]
compData = usage.Completions(
optActions={"type": usage.CompleteList(["rsa", "dsa"])})
def run():
options = GeneralOptions()
try:
options.parseOptions(sys.argv[1:])
except usage.UsageError, u:
print 'ERROR: %s' % u
options.opt_help()
sys.exit(1)
log.discardLogs()
log.deferr = handleError # HACK
if options['type']:
if options['type'] == 'rsa':
generateRSAkey(options)
elif options['type'] == 'dsa':
generateDSAkey(options)
else:
sys.exit('Key type was %s, must be one of: rsa, dsa' % options['type'])
elif options['fingerprint']:
printFingerprint(options)
elif options['changepass']:
changePassPhrase(options)
elif options['showpub']:
displayPublicKey(options)
else:
options.opt_help()
sys.exit(1)
def handleError():
global exitStatus
exitStatus = 2
log.err(failure.Failure())
raise
def generateRSAkey(options):
from Crypto.PublicKey import RSA
print 'Generating public/private rsa key pair.'
key = RSA.generate(int(options['bits']), randbytes.secureRandom)
_saveKey(key, options)
def generateDSAkey(options):
from Crypto.PublicKey import DSA
print 'Generating public/private dsa key pair.'
key = DSA.generate(int(options['bits']), randbytes.secureRandom)
_saveKey(key, options)
def printFingerprint(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
if os.path.exists(options['filename']+'.pub'):
options['filename'] += '.pub'
try:
key = keys.Key.fromFile(options['filename'])
obj = key.keyObject
print '%s %s %s' % (
obj.size() + 1,
key.fingerprint(),
os.path.basename(options['filename']))
except:
sys.exit('bad key')
def changePassPhrase(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input(
'Enter file in which the key is (%s): ' % filename)
try:
key = keys.Key.fromFile(options['filename']).keyObject
except keys.EncryptedKeyError as e:
# Raised if password not supplied for an encrypted key
if not options.get('pass'):
options['pass'] = getpass.getpass('Enter old passphrase: ')
try:
key = keys.Key.fromFile(
options['filename'], passphrase=options['pass']).keyObject
except keys.BadKeyError:
sys.exit('Could not change passphrase: old passphrase error')
except keys.EncryptedKeyError as e:
sys.exit('Could not change passphrase: %s' % (e,))
except keys.BadKeyError as e:
sys.exit('Could not change passphrase: %s' % (e,))
if not options.get('newpass'):
while 1:
p1 = getpass.getpass(
'Enter new passphrase (empty for no passphrase): ')
p2 = getpass.getpass('Enter same passphrase again: ')
if p1 == p2:
break
print 'Passphrases do not match. Try again.'
options['newpass'] = p1
try:
newkeydata = keys.Key(key).toString('openssh',
extra=options['newpass'])
except Exception as e:
sys.exit('Could not change passphrase: %s' % (e,))
try:
keys.Key.fromString(newkeydata, passphrase=options['newpass'])
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
sys.exit('Could not change passphrase: %s' % (e,))
fd = open(options['filename'], 'w')
fd.write(newkeydata)
fd.close()
print 'Your identification has been saved with the new passphrase.'
def displayPublicKey(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
try:
key = keys.Key.fromFile(options['filename']).keyObject
except keys.EncryptedKeyError:
if not options.get('pass'):
options['pass'] = getpass.getpass('Enter passphrase: ')
key = keys.Key.fromFile(
options['filename'], passphrase = options['pass']).keyObject
print keys.Key(key).public().toString('openssh')
def _saveKey(key, options):
if not options['filename']:
kind = keys.objectType(key)
kind = {'ssh-rsa':'rsa','ssh-dss':'dsa'}[kind]
filename = os.path.expanduser('~/.ssh/id_%s'%kind)
options['filename'] = raw_input('Enter file in which to save the key (%s): '%filename).strip() or filename
if os.path.exists(options['filename']):
print '%s already exists.' % options['filename']
yn = raw_input('Overwrite (y/n)? ')
if yn[0].lower() != 'y':
sys.exit()
if options.get('no-passphrase'):
options['pass'] = b''
elif not options['pass']:
while 1:
p1 = getpass.getpass('Enter passphrase (empty for no passphrase): ')
p2 = getpass.getpass('Enter same passphrase again: ')
if p1 == p2:
break
print 'Passphrases do not match. Try again.'
options['pass'] = p1
keyObj = keys.Key(key)
comment = '%s@%s' % (getpass.getuser(), socket.gethostname())
filepath.FilePath(options['filename']).setContent(
keyObj.toString('openssh', options['pass']))
os.chmod(options['filename'], 33152)
filepath.FilePath(options['filename'] + '.pub').setContent(
keyObj.public().toString('openssh', comment))
print 'Your identification has been saved in %s' % options['filename']
print 'Your public key has been saved in %s.pub' % options['filename']
print 'The key fingerprint is:'
print keyObj.fingerprint()
if __name__ == '__main__':
run()

View File

@ -0,0 +1,508 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
#""" Implementation module for the `conch` command.
#"""
from twisted.conch.client import connect, default, options
from twisted.conch.error import ConchError
from twisted.conch.ssh import connection, common
from twisted.conch.ssh import session, forwarding, channel
from twisted.internet import reactor, stdio, task
from twisted.python import log, usage
import os, sys, getpass, struct, tty, fcntl, signal
class ClientOptions(options.ConchOptions):
synopsis = """Usage: conch [options] host [command]
"""
longdesc = ("conch is a SSHv2 client that allows logging into a remote "
"machine and executing commands.")
optParameters = [['escape', 'e', '~'],
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
]
optFlags = [['null', 'n', 'Redirect input from /dev/null.'],
['fork', 'f', 'Fork to background after authentication.'],
['tty', 't', 'Tty; allocate a tty even if command is given.'],
['notty', 'T', 'Do not allocate a tty.'],
['noshell', 'N', 'Do not execute a shell or command.'],
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
]
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port")},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True)]
)
localForwards = []
remoteForwards = []
def opt_escape(self, esc):
"Set escape character; ``none'' = disable"
if esc == 'none':
self['escape'] = None
elif esc[0] == '^' and len(esc) == 2:
self['escape'] = chr(ord(esc[1])-64)
elif len(esc) == 1:
self['escape'] = esc
else:
sys.exit("Bad escape character '%s'." % esc)
def opt_localforward(self, f):
"Forward local port to remote address (lport:host:port)"
localPort, remoteHost, remotePort = f.split(':') # doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
"""Forward remote port to local address (rport:host:port)"""
remotePort, connHost, connPort = f.split(':') # doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def parseArgs(self, host, *command):
self['host'] = host
self['command'] = ' '.join(command)
# Rest of code in "run"
options = None
conn = None
exitStatus = 0
old = None
_inRawMode = 0
_savedRawMode = None
def run():
global options, old
args = sys.argv[1:]
if '-l' in args: # cvs is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == '-o' and args[i+1][0]!='-':
args[i:i+2] = [] # suck on it scp
except ValueError:
pass
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError, u:
print 'ERROR: %s' % u
options.opt_help()
sys.exit(1)
if options['log']:
if options['logfile']:
if options['logfile'] == '-':
f = sys.stdout
else:
f = file(options['logfile'], 'a+')
else:
f = sys.stderr
realout = sys.stdout
log.startLogging(f)
sys.stdout = realout
else:
log.discardLogs()
doConnect()
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
except:
old = None
try:
oldUSR1 = signal.signal(signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect))
except:
oldUSR1 = None
try:
reactor.run()
finally:
if old:
tty.tcsetattr(fd, tty.TCSANOW, old)
if oldUSR1:
signal.signal(signal.SIGUSR1, oldUSR1)
if (options['command'] and options['tty']) or not options['notty']:
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
if sys.stdout.isatty() and not options['command']:
print 'Connection to %s closed.' % options['host']
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
reactor.callLater(0.01, _stopReactor)
log.err(failure.Failure())
raise
def _stopReactor():
try:
reactor.stop()
except: pass
def doConnect():
# log.deferr = handleError # HACK
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@',1)
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
host = options['host']
if not options['user']:
options['user'] = getpass.getuser()
if not options['port']:
options['port'] = 22
else:
options['port'] = int(options['port'])
host = options['host']
port = options['port']
vhk = default.verifyHostKey
uao = default.SSHUserAuthClient(options['user'], options, SSHConnection())
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
global exitStatus
exitStatus = "conch: exiting with error %s" % f
reactor.callLater(0.1, _stopReactor)
def onConnect():
# if keyAgent and options['agent']:
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
if hasattr(conn.transport, 'sendIgnore'):
_KeepAlive(conn)
if options.localForwards:
for localPort, hostport in options.localForwards:
s = reactor.listenTCP(localPort,
forwarding.SSHListenForwardingFactory(conn,
hostport,
SSHListenClientForwardingChannel))
conn.localForwards.append(s)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg('asking for remote forwarding for %s:%s' %
(remotePort, hostport))
conn.requestRemoteForwarding(remotePort, hostport)
reactor.addSystemEventTrigger('before', 'shutdown', beforeShutdown)
if not options['noshell'] or options['agent']:
conn.openChannel(SSHSession())
if options['fork']:
if os.fork():
os._exit(0)
os.setsid()
for i in range(3):
try:
os.close(i)
except OSError, e:
import errno
if e.errno != errno.EBADF:
raise
def reConnect():
beforeShutdown()
conn.transport.transport.loseConnection()
def beforeShutdown():
remoteForwards = options.remoteForwards
for remotePort, hostport in remoteForwards:
log.msg('cancelling %s:%s' % (remotePort, hostport))
conn.cancelRemoteForwarding(remotePort)
def stopConnection():
if not options['reconnect']:
reactor.callLater(0.1, _stopReactor)
class _KeepAlive:
def __init__(self, conn):
self.conn = conn
self.globalTimeout = None
self.lc = task.LoopingCall(self.sendGlobal)
self.lc.start(300)
def sendGlobal(self):
d = self.conn.sendGlobalRequest("conch-keep-alive@twistedmatrix.com",
"", wantReply = 1)
d.addBoth(self._cbGlobal)
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
def _cbGlobal(self, res):
if self.globalTimeout:
self.globalTimeout.cancel()
self.globalTimeout = None
def _ebGlobal(self):
if self.globalTimeout:
self.globalTimeout = None
self.conn.transport.loseConnection()
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
global conn
conn = self
self.localForwards = []
self.remoteForwards = {}
if not isinstance(self, connection.SSHConnection):
# make these fall through
del self.__class__.requestRemoteForwarding
del self.__class__.cancelRemoteForwarding
onConnect()
def serviceStopped(self):
lf = self.localForwards
self.localForwards = []
for s in lf:
s.loseConnection()
stopConnection()
def requestRemoteForwarding(self, remotePort, hostport):
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
d = self.sendGlobalRequest('tcpip-forward', data,
wantReply=1)
log.msg('requesting remote forwarding %s:%s' %(remotePort, hostport))
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
def _cbRemoteForwarding(self, result, remotePort, hostport):
log.msg('accepted remote forwarding %s:%s' % (remotePort, hostport))
self.remoteForwards[remotePort] = hostport
log.msg(repr(self.remoteForwards))
def _ebRemoteForwarding(self, f, remotePort, hostport):
log.msg('remote forwarding %s:%s failed' % (remotePort, hostport))
log.msg(f)
def cancelRemoteForwarding(self, remotePort):
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
self.sendGlobalRequest('cancel-tcpip-forward', data)
log.msg('cancelling remote forwarding %s' % remotePort)
try:
del self.remoteForwards[remotePort]
except:
pass
log.msg(repr(self.remoteForwards))
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
log.msg('%s %s' % ('FTCP', repr(data)))
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
log.msg(self.remoteForwards)
log.msg(remoteHP)
if self.remoteForwards.has_key(remoteHP[1]):
connectHP = self.remoteForwards[remoteHP[1]]
log.msg('connect forwarding %s' % (connectHP,))
return SSHConnectForwardingChannel(connectHP,
remoteWindow = windowSize,
remoteMaxPacket = maxPacket,
conn = self)
else:
raise ConchError(connection.OPEN_CONNECT_FAILED, "don't know about that port")
# def channel_auth_agent_openssh_com(self, windowSize, maxPacket, data):
# if options['agent'] and keyAgent:
# return agent.SSHAgentForwardingChannel(remoteWindow = windowSize,
# remoteMaxPacket = maxPacket,
# conn = self)
# else:
# return connection.OPEN_CONNECT_FAILED, "don't have an agent"
def channelClosed(self, channel):
log.msg('connection closing %s' % channel)
log.msg(self.channels)
if len(self.channels) == 1: # just us left
log.msg('stopping connection')
stopConnection()
else:
# because of the unix thing
self.__class__.__bases__[0].channelClosed(self, channel)
class SSHSession(channel.SSHChannel):
name = 'session'
def channelOpen(self, foo):
log.msg('session %s open' % self.id)
if options['agent']:
d = self.conn.sendRequest(self, 'auth-agent-req@openssh.com', '', wantReply=1)
d.addBoth(lambda x:log.msg(x))
if options['noshell']: return
if (options['command'] and options['tty']) or not options['notty']:
_enterRawMode()
c = session.SSHSessionClient()
if options['escape'] and not options['notty']:
self.escapeMode = 1
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = lambda x=None,s=self:s.sendEOF()
self.stdio = stdio.StandardIO(c)
fd = 0
if options['subsystem']:
self.conn.sendRequest(self, 'subsystem', \
common.NS(options['command']))
elif options['command']:
if options['tty']:
term = os.environ['TERM']
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, 'pty-req', ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, 'exec', \
common.NS(options['command']))
else:
if not options['notty']:
term = os.environ['TERM']
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, 'pty-req', ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, 'shell', '')
#if hasattr(conn.transport, 'transport'):
# conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
#log.msg('handling %s' % repr(char))
if char in ('\n', '\r'):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options['escape']:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # so we can chain escapes together
if char == '.': # disconnect
log.msg('disconnecting from escape')
stopConnection()
return
elif char == '\x1a': # ^Z, suspend
def _():
_leaveRawMode()
sys.stdout.flush()
sys.stdin.flush()
os.kill(os.getpid(), signal.SIGTSTP)
_enterRawMode()
reactor.callLater(0, _)
return
elif char == 'R': # rekey connection
log.msg('rekeying connection')
self.conn.transport.sendKexInit()
return
elif char == '#': # display connections
self.stdio.write('\r\nThe following connections are open:\r\n')
channels = self.conn.channels.keys()
channels.sort()
for channelId in channels:
self.stdio.write(' #%i %s\r\n' % (channelId, str(self.conn.channels[channelId])))
return
self.write('~' + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
self.stdio.write(data)
def extReceived(self, t, data):
if t==connection.EXTENDED_DATA_STDERR:
log.msg('got %s stderr data' % len(data))
sys.stderr.write(data)
def eofReceived(self):
log.msg('got eof')
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg('remote side closed %s' % self)
self.conn.sendClose(self)
def closed(self):
global old
log.msg('closed %s' % self)
log.msg(repr(self.conn.channels))
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack('>L', data)[0])
log.msg('exit status: %s' % exitStatus)
def sendEOF(self):
self.conn.sendEOF(self)
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
def _windowResized(self, *args):
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
self.conn.sendRequest(self, 'window-change', struct.pack('!4L', *newSize))
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel): pass
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel): pass
def _leaveRawMode():
global _inRawMode
if not _inRawMode:
return
fd = sys.stdin.fileno()
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
_inRawMode = 0
def _enterRawMode():
global _inRawMode, _savedRawMode
if _inRawMode:
return
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
new = old[:]
except:
log.msg('not a typewriter!')
else:
# iflage
new[0] = new[0] | tty.IGNPAR
new[0] = new[0] & ~(tty.ISTRIP | tty.INLCR | tty.IGNCR | tty.ICRNL |
tty.IXON | tty.IXANY | tty.IXOFF)
if hasattr(tty, 'IUCLC'):
new[0] = new[0] & ~tty.IUCLC
# lflag
new[3] = new[3] & ~(tty.ISIG | tty.ICANON | tty.ECHO | tty.ECHO |
tty.ECHOE | tty.ECHOK | tty.ECHONL)
if hasattr(tty, 'IEXTEN'):
new[3] = new[3] & ~tty.IEXTEN
#oflag
new[1] = new[1] & ~tty.OPOST
new[6][tty.VMIN] = 1
new[6][tty.VTIME] = 0
_savedRawMode = old
tty.tcsetattr(fd, tty.TCSANOW, new)
#tty.setraw(fd)
_inRawMode = 1
if __name__ == '__main__':
run()

View File

@ -0,0 +1,573 @@
# -*- test-case-name: twisted.conch.test.test_scripts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `tkconch` command.
"""
import Tkinter, tkFileDialog, tkMessageBox
from twisted.conch import error
from twisted.conch.ui import tkvt100
from twisted.conch.ssh import transport, userauth, connection, common, keys
from twisted.conch.ssh import session, forwarding, channel
from twisted.conch.client.default import isInKnownHosts
from twisted.internet import reactor, defer, protocol, tksupport
from twisted.python import usage, log
import os, sys, getpass, struct, base64, signal
class TkConchMenu(Tkinter.Frame):
def __init__(self, *args, **params):
## Standard heading: initialization
apply(Tkinter.Frame.__init__, (self,) + args, params)
self.master.title('TkConch')
self.localRemoteVar = Tkinter.StringVar()
self.localRemoteVar.set('local')
Tkinter.Label(self, anchor='w', justify='left', text='Hostname').grid(column=1, row=1, sticky='w')
self.host = Tkinter.Entry(self)
self.host.grid(column=2, columnspan=2, row=1, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Port').grid(column=1, row=2, sticky='w')
self.port = Tkinter.Entry(self)
self.port.grid(column=2, columnspan=2, row=2, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Username').grid(column=1, row=3, sticky='w')
self.user = Tkinter.Entry(self)
self.user.grid(column=2, columnspan=2, row=3, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Command').grid(column=1, row=4, sticky='w')
self.command = Tkinter.Entry(self)
self.command.grid(column=2, columnspan=2, row=4, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Identity').grid(column=1, row=5, sticky='w')
self.identity = Tkinter.Entry(self)
self.identity.grid(column=2, row=5, sticky='nesw')
Tkinter.Button(self, command=self.getIdentityFile, text='Browse').grid(column=3, row=5, sticky='nesw')
Tkinter.Label(self, text='Port Forwarding').grid(column=1, row=6, sticky='w')
self.forwards = Tkinter.Listbox(self, height=0, width=0)
self.forwards.grid(column=2, columnspan=2, row=6, sticky='nesw')
Tkinter.Button(self, text='Add', command=self.addForward).grid(column=1, row=7)
Tkinter.Button(self, text='Remove', command=self.removeForward).grid(column=1, row=8)
self.forwardPort = Tkinter.Entry(self)
self.forwardPort.grid(column=2, row=7, sticky='nesw')
Tkinter.Label(self, text='Port').grid(column=3, row=7, sticky='nesw')
self.forwardHost = Tkinter.Entry(self)
self.forwardHost.grid(column=2, row=8, sticky='nesw')
Tkinter.Label(self, text='Host').grid(column=3, row=8, sticky='nesw')
self.localForward = Tkinter.Radiobutton(self, text='Local', variable=self.localRemoteVar, value='local')
self.localForward.grid(column=2, row=9)
self.remoteForward = Tkinter.Radiobutton(self, text='Remote', variable=self.localRemoteVar, value='remote')
self.remoteForward.grid(column=3, row=9)
Tkinter.Label(self, text='Advanced Options').grid(column=1, columnspan=3, row=10, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Cipher').grid(column=1, row=11, sticky='w')
self.cipher = Tkinter.Entry(self, name='cipher')
self.cipher.grid(column=2, columnspan=2, row=11, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='MAC').grid(column=1, row=12, sticky='w')
self.mac = Tkinter.Entry(self, name='mac')
self.mac.grid(column=2, columnspan=2, row=12, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Escape Char').grid(column=1, row=13, sticky='w')
self.escape = Tkinter.Entry(self, name='escape')
self.escape.grid(column=2, columnspan=2, row=13, sticky='nesw')
Tkinter.Button(self, text='Connect!', command=self.doConnect).grid(column=1, columnspan=3, row=14, sticky='nesw')
# Resize behavior(s)
self.grid_rowconfigure(6, weight=1, minsize=64)
self.grid_columnconfigure(2, weight=1, minsize=2)
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
def getIdentityFile(self):
r = tkFileDialog.askopenfilename()
if r:
self.identity.delete(0, Tkinter.END)
self.identity.insert(Tkinter.END, r)
def addForward(self):
port = self.forwardPort.get()
self.forwardPort.delete(0, Tkinter.END)
host = self.forwardHost.get()
self.forwardHost.delete(0, Tkinter.END)
if self.localRemoteVar.get() == 'local':
self.forwards.insert(Tkinter.END, 'L:%s:%s' % (port, host))
else:
self.forwards.insert(Tkinter.END, 'R:%s:%s' % (port, host))
def removeForward(self):
cur = self.forwards.curselection()
if cur:
self.forwards.remove(cur[0])
def doConnect(self):
finished = 1
options['host'] = self.host.get()
options['port'] = self.port.get()
options['user'] = self.user.get()
options['command'] = self.command.get()
cipher = self.cipher.get()
mac = self.mac.get()
escape = self.escape.get()
if cipher:
if cipher in SSHClientTransport.supportedCiphers:
SSHClientTransport.supportedCiphers = [cipher]
else:
tkMessageBox.showerror('TkConch', 'Bad cipher.')
finished = 0
if mac:
if mac in SSHClientTransport.supportedMACs:
SSHClientTransport.supportedMACs = [mac]
elif finished:
tkMessageBox.showerror('TkConch', 'Bad MAC.')
finished = 0
if escape:
if escape == 'none':
options['escape'] = None
elif escape[0] == '^' and len(escape) == 2:
options['escape'] = chr(ord(escape[1])-64)
elif len(escape) == 1:
options['escape'] = escape
elif finished:
tkMessageBox.showerror('TkConch', "Bad escape character '%s'." % escape)
finished = 0
if self.identity.get():
options.identitys.append(self.identity.get())
for line in self.forwards.get(0,Tkinter.END):
if line[0]=='L':
options.opt_localforward(line[2:])
else:
options.opt_remoteforward(line[2:])
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@',1)
if (not options['host'] or not options['user']) and finished:
tkMessageBox.showerror('TkConch', 'Missing host or username.')
finished = 0
if finished:
self.master.quit()
self.master.destroy()
if options['log']:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
log.deferr = handleError # HACK
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
host = options['host']
port = int(options['port'] or 22)
log.msg((host,port))
reactor.connectTCP(host, port, SSHClientFactory())
frame.master.deiconify()
frame.master.title('%s@%s - TkConch' % (options['user'], options['host']))
else:
self.focus()
class GeneralOptions(usage.Options):
synopsis = """Usage: tkconch [options] host [command]
"""
optParameters = [['user', 'l', None, 'Log in using this user name.'],
['identity', 'i', '~/.ssh/identity', 'Identity for public key authentication'],
['escape', 'e', '~', "Set escape character; ``none'' = disable"],
['cipher', 'c', None, 'Select encryption algorithm.'],
['macs', 'm', None, 'Specify MAC algorithms for protocol version 2.'],
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
]
optFlags = [['tty', 't', 'Tty; allocate a tty even if command is given.'],
['notty', 'T', 'Do not allocate a tty.'],
['version', 'V', 'Display version number only.'],
['compress', 'C', 'Enable compression.'],
['noshell', 'N', 'Do not execute a shell or command.'],
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
['log', 'v', 'Log to stderr'],
['ansilog', 'a', 'Print the received data to stdout']]
_ciphers = transport.SSHClientTransport.supportedCiphers
_macs = transport.SSHClientTransport.supportedMACs
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"cipher": usage.CompleteList(_ciphers),
"macs": usage.CompleteList(_macs),
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port")},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True)]
)
identitys = []
localForwards = []
remoteForwards = []
def opt_identity(self, i):
self.identitys.append(i)
def opt_localforward(self, f):
localPort, remoteHost, remotePort = f.split(':') # doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
remotePort, connHost, connPort = f.split(':') # doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def opt_compress(self):
SSHClientTransport.supportedCompressions[0:1] = ['zlib']
def parseArgs(self, *args):
if args:
self['host'] = args[0]
self['command'] = ' '.join(args[1:])
else:
self['host'] = ''
self['command'] = ''
# Rest of code in "run"
options = None
menu = None
exitStatus = 0
frame = None
def deferredAskFrame(question, echo):
if frame.callback:
raise ValueError("can't ask 2 questions at once!")
d = defer.Deferred()
resp = []
def gotChar(ch, resp=resp):
if not ch: return
if ch=='\x03': # C-c
reactor.stop()
if ch=='\r':
frame.write('\r\n')
stresp = ''.join(resp)
del resp
frame.callback = None
d.callback(stresp)
return
elif 32 <= ord(ch) < 127:
resp.append(ch)
if echo:
frame.write(ch)
elif ord(ch) == 8 and resp: # BS
if echo: frame.write('\x08 \x08')
resp.pop()
frame.callback = gotChar
frame.write(question)
frame.canvas.focus_force()
return d
def run():
global menu, options, frame
args = sys.argv[1:]
if '-l' in args: # cvs is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == '-o' and args[i+1][0]!='-':
args[i:i+2] = [] # suck on it scp
except ValueError:
pass
root = Tkinter.Tk()
root.withdraw()
top = Tkinter.Toplevel()
menu = TkConchMenu(top)
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
options = GeneralOptions()
try:
options.parseOptions(args)
except usage.UsageError, u:
print 'ERROR: %s' % u
options.opt_help()
sys.exit(1)
for k,v in options.items():
if v and hasattr(menu, k):
getattr(menu,k).insert(Tkinter.END, v)
for (p, (rh, rp)) in options.localForwards:
menu.forwards.insert(Tkinter.END, 'L:%s:%s:%s' % (p, rh, rp))
options.localForwards = []
for (p, (rh, rp)) in options.remoteForwards:
menu.forwards.insert(Tkinter.END, 'R:%s:%s:%s' % (p, rh, rp))
options.remoteForwards = []
frame = tkvt100.VT100Frame(root, callback=None)
root.geometry('%dx%d'%(tkvt100.fontWidth*frame.width+3, tkvt100.fontHeight*frame.height+3))
frame.pack(side = Tkinter.TOP)
tksupport.install(root)
root.withdraw()
if (options['host'] and options['user']) or '@' in options['host']:
menu.doConnect()
else:
top.mainloop()
reactor.run()
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
log.err(failure.Failure())
reactor.stop()
raise
class SSHClientFactory(protocol.ClientFactory):
noisy = 1
def stopFactory(self):
reactor.stop()
def buildProtocol(self, addr):
return SSHClientTransport()
def clientConnectionFailed(self, connector, reason):
tkMessageBox.showwarning('TkConch','Connection Failed, Reason:\n %s: %s' % (reason.type, reason.value))
class SSHClientTransport(transport.SSHClientTransport):
def receiveError(self, code, desc):
global exitStatus
exitStatus = 'conch:\tRemote side disconnected with error code %i\nconch:\treason: %s' % (code, desc)
def sendDisconnect(self, code, reason):
global exitStatus
exitStatus = 'conch:\tSending disconnect with error code %i\nconch:\treason: %s' % (code, reason)
transport.SSHClientTransport.sendDisconnect(self, code, reason)
def receiveDebug(self, alwaysDisplay, message, lang):
global options
if alwaysDisplay or options['log']:
log.msg('Received Debug Message: %s' % message)
def verifyHostKey(self, pubKey, fingerprint):
#d = defer.Deferred()
#d.addCallback(lambda x:defer.succeed(1))
#d.callback(2)
#return d
goodKey = isInKnownHosts(options['host'], pubKey, {'known-hosts': None})
if goodKey == 1: # good key
return defer.succeed(1)
elif goodKey == 2: # AAHHHHH changed
return defer.fail(error.ConchError('bad host key'))
else:
if options['host'] == self.transport.getPeer()[1]:
host = options['host']
khHost = options['host']
else:
host = '%s (%s)' % (options['host'],
self.transport.getPeer()[1])
khHost = '%s,%s' % (options['host'],
self.transport.getPeer()[1])
keyType = common.getNS(pubKey)[0]
ques = """The authenticity of host '%s' can't be established.\r
%s key fingerprint is %s.""" % (host,
{'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType],
fingerprint)
ques+='\r\nAre you sure you want to continue connecting (yes/no)? '
return deferredAskFrame(ques, 1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
if ans.lower() not in ('yes', 'no'):
return deferredAskFrame("Please type 'yes' or 'no': ",1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
if ans.lower() == 'no':
frame.write('Host key verification failed.\r\n')
raise error.ConchError('bad host key')
try:
frame.write("Warning: Permanently added '%s' (%s) to the list of known hosts.\r\n" % (khHost, {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType]))
known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'), 'a')
encodedKey = base64.encodestring(pubKey).replace('\n', '')
known_hosts.write('\n%s %s %s' % (khHost, keyType, encodedKey))
known_hosts.close()
except:
log.deferr()
raise error.ConchError
def connectionSecure(self):
if options['user']:
user = options['user']
else:
user = getpass.getuser()
self.requestService(SSHUserAuthClient(user, SSHConnection()))
class SSHUserAuthClient(userauth.SSHUserAuthClient):
usedFiles = []
def getPassword(self, prompt = None):
if not prompt:
prompt = "%s@%s's password: " % (self.user, options['host'])
return deferredAskFrame(prompt,0)
def getPublicKey(self):
files = [x for x in options.identitys if x not in self.usedFiles]
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += '.pub'
if not os.path.exists(file):
return
try:
return keys.Key.fromFile(file).blob()
except:
return self.getPublicKey() # try again
def getPrivateKey(self):
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file).keyObject)
except keys.BadKeyError, e:
if e.args[0] == 'encrypted key with no password':
prompt = "Enter passphrase for key '%s': " % \
self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
def _cbGetPrivateKey(self, ans, count):
file = os.path.expanduser(self.usedFiles[-1])
try:
return keys.Key.fromFile(file, password = ans).keyObject
except keys.BadKeyError:
if count == 2:
raise
prompt = "Enter passphrase for key '%s': " % \
self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, count+1)
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
if not options['noshell']:
self.openChannel(SSHSession())
if options.localForwards:
for localPort, hostport in options.localForwards:
reactor.listenTCP(localPort,
forwarding.SSHListenForwardingFactory(self,
hostport,
forwarding.SSHListenClientForwardingChannel))
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg('asking for remote forwarding for %s:%s' %
(remotePort, hostport))
data = forwarding.packGlobal_tcpip_forward(
('0.0.0.0', remotePort))
self.sendGlobalRequest('tcpip-forward', data)
self.remoteForwards[remotePort] = hostport
class SSHSession(channel.SSHChannel):
name = 'session'
def channelOpen(self, foo):
#global globalSession
#globalSession = self
# turn off local echo
self.escapeMode = 1
c = session.SSHSessionClient()
if options['escape']:
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = self.sendEOF
frame.callback = c.dataReceived
frame.canvas.focus_force()
if options['subsystem']:
self.conn.sendRequest(self, 'subsystem', \
common.NS(options['command']))
elif options['command']:
if options['tty']:
term = os.environ.get('TERM', 'xterm')
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, 'pty-req', ptyReqData)
self.conn.sendRequest(self, 'exec', \
common.NS(options['command']))
else:
if not options['notty']:
term = os.environ.get('TERM', 'xterm')
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, 'pty-req', ptyReqData)
self.conn.sendRequest(self, 'shell', '')
self.conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
#log.msg('handling %s' % repr(char))
if char in ('\n', '\r'):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options['escape']:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # so we can chain escapes together
if char == '.': # disconnect
log.msg('disconnecting from escape')
reactor.stop()
return
elif char == '\x1a': # ^Z, suspend
# following line courtesy of Erwin@freenode
os.kill(os.getpid(), signal.SIGSTOP)
return
elif char == 'R': # rekey connection
log.msg('rekeying connection')
self.conn.transport.sendKexInit()
return
self.write('~' + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
if options['ansilog']:
print repr(data)
frame.write(data)
def extReceived(self, t, data):
if t==connection.EXTENDED_DATA_STDERR:
log.msg('got %s stderr data' % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg('got eof')
sys.stdin.close()
def closed(self):
log.msg('closed %s' % self)
if len(self.conn.channels) == 1: # just us left
reactor.stop()
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack('>L', data)[0])
log.msg('exit status: %s' % exitStatus)
def sendEOF(self):
self.conn.sendEOF(self)
if __name__=="__main__":
run()

View File

@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
Maintainer: Paul Swartz
"""

View File

@ -0,0 +1,41 @@
# -*- test-case-name: twisted.conch.test.test_address -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Address object for SSH network connections.
Maintainer: Paul Swartz
@since: 12.1
"""
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python import util
@implementer(IAddress)
class SSHTransportAddress(object, util.FancyEqMixin):
"""
Object representing an SSH Transport endpoint.
@ivar address: A instance of an object which implements I{IAddress} to
which this transport address is connected.
"""
compareAttributes = ('address',)
def __init__(self, address):
self.address = address
def __repr__(self):
return 'SSHTransportAddress(%r)' % (self.address,)
def __hash__(self):
return hash(('SSH', self.address))

View File

@ -0,0 +1,294 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implements the SSH v2 key agent protocol. This protocol is documented in the
SSH source code, in the file
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.ssh.common import NS, getNS, getMP
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.ssh import keys
from twisted.internet import defer, protocol
class SSHAgentClient(protocol.Protocol):
"""
The client side of the SSH agent protocol. This is equivalent to
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
protocol, also in this package.
"""
def __init__(self):
self.buf = ''
self.deferreds = []
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack('!L', self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
reqType = ord(packet[0])
d = self.deferreds.pop(0)
if reqType == AGENT_FAILURE:
d.errback(ConchError('agent failure'))
elif reqType == AGENT_SUCCESS:
d.callback('')
else:
d.callback(packet)
def sendRequest(self, reqType, data):
pack = struct.pack('!LB',len(data) + 1, reqType) + data
self.transport.write(pack)
d = defer.Deferred()
self.deferreds.append(d)
return d
def requestIdentities(self):
"""
@return: A L{Deferred} which will fire with a list of all keys found in
the SSH agent. The list of keys is comprised of (public key blob,
comment) tuples.
"""
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, '')
d.addCallback(self._cbRequestIdentities)
return d
def _cbRequestIdentities(self, data):
"""
Unpack a collection of identities into a list of tuples comprised of
public key blobs and comments.
"""
if ord(data[0]) != AGENT_IDENTITIES_ANSWER:
raise ConchError('unexpected response: %i' % ord(data[0]))
numKeys = struct.unpack('!L', data[1:5])[0]
result = []
data = data[5:]
for i in range(numKeys):
blob, data = getNS(data)
comment, data = getNS(data)
result.append((blob, comment))
return result
def addIdentity(self, blob, comment = ''):
"""
Add a private key blob to the agent's collection of keys.
"""
req = blob
req += NS(comment)
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
def signData(self, blob, data):
"""
Request that the agent sign the given C{data} with the private key
which corresponds to the public key given by C{blob}. The private
key should have been added to the agent already.
@type blob: C{str}
@type data: C{str}
@return: A L{Deferred} which fires with a signature for given data
created with the given key.
"""
req = NS(blob)
req += NS(data)
req += '\000\000\000\000' # flags
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
def _cbSignData(self, data):
if ord(data[0]) != AGENT_SIGN_RESPONSE:
raise ConchError('unexpected data: %i' % ord(data[0]))
signature = getNS(data[1:])[0]
return signature
def removeIdentity(self, blob):
"""
Remove the private key corresponding to the public key in blob from the
running agent.
"""
req = NS(blob)
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
def removeAllIdentities(self):
"""
Remove all keys from the running agent.
"""
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, '')
class SSHAgentServer(protocol.Protocol):
"""
The server side of the SSH agent protocol. This is equivalent to
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
protocol, also in this package.
"""
def __init__(self):
self.buf = ''
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack('!L', self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
reqType = ord(packet[0])
reqName = messages.get(reqType, None)
if not reqName:
self.sendResponse(AGENT_FAILURE, '')
else:
f = getattr(self, 'agentc_%s' % reqName)
if getattr(self.factory, 'keys', None) is None:
self.sendResponse(AGENT_FAILURE, '')
raise MissingKeyStoreError()
f(packet[1:])
def sendResponse(self, reqType, data):
pack = struct.pack('!LB', len(data) + 1, reqType) + data
self.transport.write(pack)
def agentc_REQUEST_IDENTITIES(self, data):
"""
Return all of the identities that have been added to the server
"""
assert data == ''
numKeys = len(self.factory.keys)
resp = []
resp.append(struct.pack('!L', numKeys))
for key, comment in self.factory.keys.itervalues():
resp.append(NS(key.blob())) # yes, wrapped in an NS
resp.append(NS(comment))
self.sendResponse(AGENT_IDENTITIES_ANSWER, ''.join(resp))
def agentc_SIGN_REQUEST(self, data):
"""
Data is a structure with a reference to an already added key object and
some data that the clients wants signed with that key. If the key
object wasn't loaded, return AGENT_FAILURE, else return the signature.
"""
blob, data = getNS(data)
if blob not in self.factory.keys:
return self.sendResponse(AGENT_FAILURE, '')
signData, data = getNS(data)
assert data == '\000\000\000\000'
self.sendResponse(AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData)))
def agentc_ADD_IDENTITY(self, data):
"""
Adds a private key to the agent's collection of identities. On
subsequent interactions, the private key can be accessed using only the
corresponding public key.
"""
# need to pre-read the key data so we can get past it to the comment string
keyType, rest = getNS(data)
if keyType == 'ssh-rsa':
nmp = 6
elif keyType == 'ssh-dss':
nmp = 5
else:
raise keys.BadKeyError('unknown blob type: %s' % keyType)
rest = getMP(rest, nmp)[-1] # ignore the key data for now, we just want the comment
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
k = keys.Key.fromString(data, type='private_blob') # not wrapped in NS here
self.factory.keys[k.blob()] = (k, comment)
self.sendResponse(AGENT_SUCCESS, '')
def agentc_REMOVE_IDENTITY(self, data):
"""
Remove a specific key from the agent's collection of identities.
"""
blob, _ = getNS(data)
k = keys.Key.fromString(blob, type='blob')
del self.factory.keys[k.blob()]
self.sendResponse(AGENT_SUCCESS, '')
def agentc_REMOVE_ALL_IDENTITIES(self, data):
"""
Remove all keys from the agent's collection of identities.
"""
assert data == ''
self.factory.keys = {}
self.sendResponse(AGENT_SUCCESS, '')
# v1 messages that we ignore because we don't keep v1 keys
# open-ssh sends both v1 and v2 commands, so we have to
# do no-ops for v1 commands or we'll get "bad request" errors
def agentc_REQUEST_RSA_IDENTITIES(self, data):
"""
v1 message for listing RSA1 keys; superseded by
agentc_REQUEST_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack('!L', 0))
def agentc_REMOVE_RSA_IDENTITY(self, data):
"""
v1 message for removing RSA1 keys; superseded by
agentc_REMOVE_IDENTITY, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, '')
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
"""
v1 message for removing all RSA1 keys; superseded by
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, '')
AGENTC_REQUEST_RSA_IDENTITIES = 1
AGENT_RSA_IDENTITIES_ANSWER = 2
AGENT_FAILURE = 5
AGENT_SUCCESS = 6
AGENTC_REMOVE_RSA_IDENTITY = 8
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
AGENTC_REQUEST_IDENTITIES = 11
AGENT_IDENTITIES_ANSWER = 12
AGENTC_SIGN_REQUEST = 13
AGENT_SIGN_RESPONSE = 14
AGENTC_ADD_IDENTITY = 17
AGENTC_REMOVE_IDENTITY = 18
AGENTC_REMOVE_ALL_IDENTITIES = 19
messages = {}
for name, value in locals().copy().items():
if name[:7] == 'AGENTC_':
messages[value] = name[7:] # doesn't handle doubles

View File

@ -0,0 +1,281 @@
# -*- test-case-name: twisted.conch.test.test_channel -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
The parent class for all the SSH Channels. Currently implemented channels
are session. direct-tcp, and forwarded-tcp.
Maintainer: Paul Swartz
"""
from zope.interface import implementer
from twisted.python import log
from twisted.internet import interfaces
@implementer(interfaces.ITransport)
class SSHChannel(log.Logger):
"""
A class that represents a multiplexed channel over an SSH connection.
The channel has a local window which is the maximum amount of data it will
receive, and a remote which is the maximum amount of data the remote side
will accept. There is also a maximum packet size for any individual data
packet going each way.
@ivar name: the name of the channel.
@type name: C{str}
@ivar localWindowSize: the maximum size of the local window in bytes.
@type localWindowSize: C{int}
@ivar localWindowLeft: how many bytes are left in the local window.
@type localWindowLeft: C{int}
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
@type localMaxPacket: C{int}
@ivar remoteWindowLeft: how many bytes are left in the remote window.
@type remoteWindowLeft: C{int}
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
accept in bytes.
@type remoteMaxPacket: C{int}
@ivar conn: the connection this channel is multiplexed through.
@type conn: L{SSHConnection}
@ivar data: any data to send to the other size when the channel is
requested.
@type data: C{str}
@ivar avatar: an avatar for the logged-in user (if a server channel)
@ivar localClosed: True if we aren't accepting more data.
@type localClosed: C{bool}
@ivar remoteClosed: True if the other size isn't accepting more data.
@type remoteClosed: C{bool}
"""
name = None # only needed for client channels
def __init__(self, localWindow = 0, localMaxPacket = 0,
remoteWindow = 0, remoteMaxPacket = 0,
conn = None, data=None, avatar = None):
self.localWindowSize = localWindow or 131072
self.localWindowLeft = self.localWindowSize
self.localMaxPacket = localMaxPacket or 32768
self.remoteWindowLeft = remoteWindow
self.remoteMaxPacket = remoteMaxPacket
self.areWriting = 1
self.conn = conn
self.data = data
self.avatar = avatar
self.specificData = ''
self.buf = ''
self.extBuf = []
self.closing = 0
self.localClosed = 0
self.remoteClosed = 0
self.id = None # gets set later by SSHConnection
def __str__(self):
return '<SSHChannel %s (lw %i rw %i)>' % (self.name,
self.localWindowLeft, self.remoteWindowLeft)
def logPrefix(self):
id = (self.id is not None and str(self.id)) or "unknown"
return "SSHChannel %s (%s) on %s" % (self.name, id,
self.conn.logPrefix())
def channelOpen(self, specificData):
"""
Called when the channel is opened. specificData is any data that the
other side sent us when opening the channel.
@type specificData: C{str}
"""
log.msg('channel open')
def openFailed(self, reason):
"""
Called when the open failed for some reason.
reason.desc is a string descrption, reason.code the SSH error code.
@type reason: L{error.ConchError}
"""
log.msg('other side refused open\nreason: %s'% reason)
def addWindowBytes(self, bytes):
"""
Called when bytes are added to the remote window. By default it clears
the data buffers.
@type bytes: C{int}
"""
self.remoteWindowLeft = self.remoteWindowLeft+bytes
if not self.areWriting and not self.closing:
self.areWriting = True
self.startWriting()
if self.buf:
b = self.buf
self.buf = ''
self.write(b)
if self.extBuf:
b = self.extBuf
self.extBuf = []
for (type, data) in b:
self.writeExtended(type, data)
def requestReceived(self, requestType, data):
"""
Called when a request is sent to this channel. By default it delegates
to self.request_<requestType>.
If this function returns true, the request succeeded, otherwise it
failed.
@type requestType: C{str}
@type data: C{str}
@rtype: C{bool}
"""
foo = requestType.replace('-', '_')
f = getattr(self, 'request_%s'%foo, None)
if f:
return f(data)
log.msg('unhandled request for %s'%requestType)
return 0
def dataReceived(self, data):
"""
Called when we receive data.
@type data: C{str}
"""
log.msg('got data %s'%repr(data))
def extReceived(self, dataType, data):
"""
Called when we receive extended data (usually standard error).
@type dataType: C{int}
@type data: C{str}
"""
log.msg('got extended data %s %s'%(dataType, repr(data)))
def eofReceived(self):
"""
Called when the other side will send no more data.
"""
log.msg('remote eof')
def closeReceived(self):
"""
Called when the other side has closed the channel.
"""
log.msg('remote close')
self.loseConnection()
def closed(self):
"""
Called when the channel is closed. This means that both our side and
the remote side have closed the channel.
"""
log.msg('closed')
# transport stuff
def write(self, data):
"""
Write some data to the channel. If there is not enough remote window
available, buffer until it is. Otherwise, split the data into
packets of length remoteMaxPacket and send them.
@type data: C{str}
"""
if self.buf:
self.buf += data
return
top = len(data)
if top > self.remoteWindowLeft:
data, self.buf = (data[:self.remoteWindowLeft],
data[self.remoteWindowLeft:])
self.areWriting = 0
self.stopWriting()
top = self.remoteWindowLeft
rmp = self.remoteMaxPacket
write = self.conn.sendData
r = range(0, top, rmp)
for offset in r:
write(self, data[offset: offset+rmp])
self.remoteWindowLeft -= top
if self.closing and not self.buf:
self.loseConnection() # try again
def writeExtended(self, dataType, data):
"""
Send extended data to this channel. If there is not enough remote
window available, buffer until there is. Otherwise, split the data
into packets of length remoteMaxPacket and send them.
@type dataType: C{int}
@type data: C{str}
"""
if self.extBuf:
if self.extBuf[-1][0] == dataType:
self.extBuf[-1][1] += data
else:
self.extBuf.append([dataType, data])
return
if len(data) > self.remoteWindowLeft:
data, self.extBuf = (data[:self.remoteWindowLeft],
[[dataType, data[self.remoteWindowLeft:]]])
self.areWriting = 0
self.stopWriting()
while len(data) > self.remoteMaxPacket:
self.conn.sendExtendedData(self, dataType,
data[:self.remoteMaxPacket])
data = data[self.remoteMaxPacket:]
self.remoteWindowLeft -= self.remoteMaxPacket
if data:
self.conn.sendExtendedData(self, dataType, data)
self.remoteWindowLeft -= len(data)
if self.closing:
self.loseConnection() # try again
def writeSequence(self, data):
"""
Part of the Transport interface. Write a list of strings to the
channel.
@type data: C{list} of C{str}
"""
self.write(''.join(data))
def loseConnection(self):
"""
Close the channel if there is no buferred data. Otherwise, note the
request and return.
"""
self.closing = 1
if not self.buf and not self.extBuf:
self.conn.sendClose(self)
def getPeer(self):
"""
Return a tuple describing the other side of the connection.
@rtype: C{tuple}
"""
return('SSH', )+self.conn.transport.getPeer()
def getHost(self):
"""
Return a tuple describing our side of the connection.
@rtype: C{tuple}
"""
return('SSH', )+self.conn.transport.getHost()
def stopWriting(self):
"""
Called when the remote buffer is full, as a hint to stop writing.
This can be ignored, but it can be helpful.
"""
def startWriting(self):
"""
Called when the remote buffer has more room, as a hint to continue
writing.
"""

View File

@ -0,0 +1,116 @@
# -*- test-case-name: twisted.conch.test.test_ssh -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functions for the SSH classes.
Maintainer: Paul Swartz
"""
import struct, warnings, __builtin__
try:
from Crypto import Util
except ImportError:
warnings.warn("PyCrypto not installed, but continuing anyways!",
RuntimeWarning)
def NS(t):
"""
net string
"""
return struct.pack('!L',len(t)) + t
def getNS(s, count=1):
"""
get net string
"""
ns = []
c = 0
for i in range(count):
l, = struct.unpack('!L',s[c:c+4])
ns.append(s[c+4:4+l+c])
c += 4 + l
return tuple(ns) + (s[c:],)
def MP(number):
if number==0: return '\000'*4
assert number>0
bn = Util.number.long_to_bytes(number)
if ord(bn[0])&128:
bn = '\000' + bn
return struct.pack('>L',len(bn)) + bn
def getMP(data, count=1):
"""
Get multiple precision integer out of the string. A multiple precision
integer is stored as a 4-byte length followed by length bytes of the
integer. If count is specified, get count integers out of the string.
The return value is a tuple of count integers followed by the rest of
the data.
"""
mp = []
c = 0
for i in range(count):
length, = struct.unpack('>L',data[c:c+4])
mp.append(Util.number.bytes_to_long(data[c+4:c+4+length]))
c += 4 + length
return tuple(mp) + (data[c:],)
def _MPpow(x, y, z):
"""return the MP version of (x**y)%z
"""
return MP(pow(x,y,z))
def ffs(c, s):
"""
first from second
goes through the first list, looking for items in the second, returns the first one
"""
for i in c:
if i in s: return i
getMP_py = getMP
MP_py = MP
_MPpow_py = _MPpow
pyPow = pow
def _fastgetMP(data, count=1):
mp = []
c = 0
for i in range(count):
length = struct.unpack('!L', data[c:c+4])[0]
mp.append(long(gmpy.mpz(data[c + 4:c + 4 + length][::-1] + '\x00', 256)))
c += length + 4
return tuple(mp) + (data[c:],)
def _fastMP(i):
i2 = gmpy.mpz(i).binary()[::-1]
return struct.pack('!L', len(i2)) + i2
def _fastMPpow(x, y, z=None):
r = pyPow(gmpy.mpz(x),y,z).binary()[::-1]
return struct.pack('!L', len(r)) + r
def install():
global getMP, MP, _MPpow
getMP = _fastgetMP
MP = _fastMP
_MPpow = _fastMPpow
# XXX: We override builtin pow so that PyCrypto can benefit from gmpy too.
def _fastpow(x, y, z=None, mpz=gmpy.mpz):
if type(x) in (long, int):
x = mpz(x)
return pyPow(x, y, z)
__builtin__.pow = _fastpow # evil evil
try:
import gmpy
install()
except ImportError:
pass

View File

@ -0,0 +1,636 @@
# -*- test-case-name: twisted.conch.test.test_connection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the ssh-connection service, which
allows access to the shell and port-forwarding.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch.ssh import service, common
from twisted.conch import error
from twisted.internet import defer
from twisted.python import log
class SSHConnection(service.SSHService):
"""
An implementation of the 'ssh-connection' service. It is used to
multiplex multiple channels over the single SSH connection.
@ivar localChannelID: the next number to use as a local channel ID.
@type localChannelID: C{int}
@ivar channels: a C{dict} mapping a local channel ID to C{SSHChannel}
subclasses.
@type channels: C{dict}
@ivar localToRemoteChannel: a C{dict} mapping a local channel ID to a
remote channel ID.
@type localToRemoteChannel: C{dict}
@ivar channelsToRemoteChannel: a C{dict} mapping a C{SSHChannel} subclass
to remote channel ID.
@type channelsToRemoteChannel: C{dict}
@ivar deferreds: a C{dict} mapping a local channel ID to a C{list} of
C{Deferreds} for outstanding channel requests. Also, the 'global'
key stores the C{list} of pending global request C{Deferred}s.
"""
name = 'ssh-connection'
def __init__(self):
self.localChannelID = 0 # this is the current # to use for channel ID
self.localToRemoteChannel = {} # local channel ID -> remote channel ID
self.channels = {} # local channel ID -> subclass of SSHChannel
self.channelsToRemoteChannel = {} # subclass of SSHChannel ->
# remote channel ID
self.deferreds = {"global": []} # local channel -> list of deferreds
# for pending requests or 'global' -> list of
# deferreds for global requests
self.transport = None # gets set later
def serviceStarted(self):
if hasattr(self.transport, 'avatar'):
self.transport.avatar.conn = self
def serviceStopped(self):
"""
Called when the connection is stopped.
"""
map(self.channelClosed, self.channels.values())
self._cleanupGlobalDeferreds()
def _cleanupGlobalDeferreds(self):
"""
All pending requests that have returned a deferred must be errbacked
when this service is stopped, otherwise they might be left uncalled and
uncallable.
"""
for d in self.deferreds["global"]:
d.errback(error.ConchError("Connection stopped."))
del self.deferreds["global"][:]
# packet methods
def ssh_GLOBAL_REQUEST(self, packet):
"""
The other side has made a global request. Payload::
string request type
bool want reply
<request specific data>
This dispatches to self.gotGlobalRequest.
"""
requestType, rest = common.getNS(packet)
wantReply, rest = ord(rest[0]), rest[1:]
ret = self.gotGlobalRequest(requestType, rest)
if wantReply:
reply = MSG_REQUEST_FAILURE
data = ''
if ret:
reply = MSG_REQUEST_SUCCESS
if isinstance(ret, (tuple, list)):
data = ret[1]
self.transport.sendPacket(reply, data)
def ssh_REQUEST_SUCCESS(self, packet):
"""
Our global request succeeded. Get the appropriate Deferred and call
it back with the packet we received.
"""
log.msg('RS')
self.deferreds['global'].pop(0).callback(packet)
def ssh_REQUEST_FAILURE(self, packet):
"""
Our global request failed. Get the appropriate Deferred and errback
it with the packet we received.
"""
log.msg('RF')
self.deferreds['global'].pop(0).errback(
error.ConchError('global request failed', packet))
def ssh_CHANNEL_OPEN(self, packet):
"""
The other side wants to get a channel. Payload::
string channel name
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
We get a channel from self.getChannel(), give it a local channel number
and notify the other side. Then notify the channel by calling its
channelOpen method.
"""
channelType, rest = common.getNS(packet)
senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
packet = rest[12:]
try:
channel = self.getChannel(channelType, windowSize, maxPacket,
packet)
localChannel = self.localChannelID
self.localChannelID += 1
channel.id = localChannel
self.channels[localChannel] = channel
self.channelsToRemoteChannel[channel] = senderChannel
self.localToRemoteChannel[localChannel] = senderChannel
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
struct.pack('>4L', senderChannel, localChannel,
channel.localWindowSize,
channel.localMaxPacket)+channel.specificData)
log.callWithLogger(channel, channel.channelOpen, packet)
except Exception, e:
log.err(e, 'channel open failed')
if isinstance(e, error.ConchError):
textualInfo, reason = e.args
if isinstance(textualInfo, (int, long)):
# See #3657 and #3071
textualInfo, reason = reason, textualInfo
else:
reason = OPEN_CONNECT_FAILED
textualInfo = "unknown failure"
self.transport.sendPacket(
MSG_CHANNEL_OPEN_FAILURE,
struct.pack('>2L', senderChannel, reason) +
common.NS(textualInfo) + common.NS(''))
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
"""
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
Find the channel using the local channel number and notify its
channelOpen method.
"""
(localChannel, remoteChannel, windowSize,
maxPacket) = struct.unpack('>4L', packet[: 16])
specificData = packet[16:]
channel = self.channels[localChannel]
channel.conn = self
self.localToRemoteChannel[localChannel] = remoteChannel
self.channelsToRemoteChannel[channel] = remoteChannel
channel.remoteWindowLeft = windowSize
channel.remoteMaxPacket = maxPacket
log.callWithLogger(channel, channel.channelOpen, specificData)
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
"""
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 reason code
string reason description
Find the channel using the local channel number and notify it by
calling its openFailed() method.
"""
localChannel, reasonCode = struct.unpack('>2L', packet[:8])
reasonDesc = common.getNS(packet[8:])[0]
channel = self.channels[localChannel]
del self.channels[localChannel]
channel.conn = self
reason = error.ConchError(reasonDesc, reasonCode)
log.callWithLogger(channel, channel.openFailed, reason)
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
"""
The other side is adding bytes to its window. Payload::
uint32 local channel number
uint32 bytes to add
Call the channel's addWindowBytes() method to add new bytes to the
remote window.
"""
localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)
def ssh_CHANNEL_DATA(self, packet):
"""
The other side is sending us data. Payload::
uint32 local channel number
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or more than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data to the channel's dataReceived().
"""
localChannel, dataLength = struct.unpack('>2L', packet[:8])
channel = self.channels[localChannel]
# XXX should this move to dataReceived to put client in charge?
if (dataLength > channel.localWindowLeft or
dataLength > channel.localMaxPacket): # more data than we want
log.callWithLogger(channel, log.msg, 'too much data')
self.sendClose(channel)
return
#packet = packet[:channel.localWindowLeft+4]
data = common.getNS(packet[4:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(channel, channel.localWindowSize - \
channel.localWindowLeft)
#log.msg('local window left: %s/%s' % (channel.localWindowLeft,
# channel.localWindowSize))
log.callWithLogger(channel, channel.dataReceived, data)
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
"""
The other side is sending us exteneded data. Payload::
uint32 local channel number
uint32 type code
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data and type code to the channel's
extReceived().
"""
localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
channel = self.channels[localChannel]
if (dataLength > channel.localWindowLeft or
dataLength > channel.localMaxPacket):
log.callWithLogger(channel, log.msg, 'too much extdata')
self.sendClose(channel)
return
data = common.getNS(packet[8:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(channel, channel.localWindowSize -
channel.localWindowLeft)
log.callWithLogger(channel, channel.extReceived, typeCode, data)
def ssh_CHANNEL_EOF(self, packet):
"""
The other side is not sending any more data. Payload::
uint32 local channel number
Notify the channel by calling its eofReceived() method.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.eofReceived)
def ssh_CHANNEL_CLOSE(self, packet):
"""
The other side is closing its end; it does not want to receive any
more data. Payload::
uint32 local channel number
Notify the channnel by calling its closeReceived() method. If
the channel has also sent a close message, call self.channelClosed().
"""
localChannel = struct.unpack('>L', packet[:4])[0]
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.closeReceived)
channel.remoteClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
def ssh_CHANNEL_REQUEST(self, packet):
"""
The other side is sending a request to a channel. Payload::
uint32 local channel number
string request name
bool want reply
<request specific data>
Pass the message to the channel's requestReceived method. If the
other side wants a reply, add callbacks which will send the
reply.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
requestType, rest = common.getNS(packet[4:])
wantReply = ord(rest[0])
channel = self.channels[localChannel]
d = defer.maybeDeferred(log.callWithLogger, channel,
channel.requestReceived, requestType, rest[1:])
if wantReply:
d.addCallback(self._cbChannelRequest, localChannel)
d.addErrback(self._ebChannelRequest, localChannel)
return d
def _cbChannelRequest(self, result, localChannel):
"""
Called back if the other side wanted a reply to a channel request. If
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
a C{error.ConchError}
@param result: the value returned from the channel's requestReceived()
method. If it's False, the request failed.
@type result: C{bool}
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: C{int}
@raises ConchError: if the result is False.
"""
if not result:
raise error.ConchError('failed request')
self.transport.sendPacket(MSG_CHANNEL_SUCCESS, struct.pack('>L',
self.localToRemoteChannel[localChannel]))
def _ebChannelRequest(self, result, localChannel):
"""
Called if the other wisde wanted a reply to the channel requeset and
the channel request failed.
@param result: a Failure, but it's not used.
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: C{int}
"""
self.transport.sendPacket(MSG_CHANNEL_FAILURE, struct.pack('>L',
self.localToRemoteChannel[localChannel]))
def ssh_CHANNEL_SUCCESS(self, packet):
"""
Our channel request to the other side succeeded. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and call it back.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
log.callWithLogger(self.channels[localChannel],
d.callback, '')
def ssh_CHANNEL_FAILURE(self, packet):
"""
Our channel request to the other side failed. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and errback it with a
C{error.ConchError}.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
log.callWithLogger(self.channels[localChannel],
d.errback,
error.ConchError('channel request failed'))
# methods for users of the connection to call
def sendGlobalRequest(self, request, data, wantReply=0):
"""
Send a global request for this connection. Current this is only used
for remote->local TCP forwarding.
@type request: C{str}
@type data: C{str}
@type wantReply: C{bool}
@rtype C{Deferred}/C{None}
"""
self.transport.sendPacket(MSG_GLOBAL_REQUEST,
common.NS(request)
+ (wantReply and '\xff' or '\x00')
+ data)
if wantReply:
d = defer.Deferred()
self.deferreds['global'].append(d)
return d
def openChannel(self, channel, extra=''):
"""
Open a new channel on this connection.
@type channel: subclass of C{SSHChannel}
@type extra: C{str}
"""
log.msg('opening channel %s with %s %s'%(self.localChannelID,
channel.localWindowSize, channel.localMaxPacket))
self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
+ struct.pack('>3L', self.localChannelID,
channel.localWindowSize, channel.localMaxPacket)
+ extra)
channel.id = self.localChannelID
self.channels[self.localChannelID] = channel
self.localChannelID += 1
def sendRequest(self, channel, requestType, data, wantReply=0):
"""
Send a request to a channel.
@type channel: subclass of C{SSHChannel}
@type requestType: C{str}
@type data: C{str}
@type wantReply: C{bool}
@rtype C{Deferred}/C{None}
"""
if channel.localClosed:
return
log.msg('sending request %s' % requestType)
self.transport.sendPacket(MSG_CHANNEL_REQUEST, struct.pack('>L',
self.channelsToRemoteChannel[channel])
+ common.NS(requestType)+chr(wantReply)
+ data)
if wantReply:
d = defer.Deferred()
self.deferreds.setdefault(channel.id, []).append(d)
return d
def adjustWindow(self, channel, bytesToAdd):
"""
Tell the other side that we will receive more data. This should not
normally need to be called as it is managed automatically.
@type channel: subclass of L{SSHChannel}
@type bytesToAdd: C{int}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, struct.pack('>2L',
self.channelsToRemoteChannel[channel],
bytesToAdd))
log.msg('adding %i to %i in channel %i' % (bytesToAdd,
channel.localWindowLeft, channel.id))
channel.localWindowLeft += bytesToAdd
def sendData(self, channel, data):
"""
Send data to a channel. This should not normally be used: instead use
channel.write(data) as it manages the window automatically.
@type channel: subclass of L{SSHChannel}
@type data: C{str}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_DATA, struct.pack('>L',
self.channelsToRemoteChannel[channel]) +
common.NS(data))
def sendExtendedData(self, channel, dataType, data):
"""
Send extended data to a channel. This should not normally be used:
instead use channel.writeExtendedData(data, dataType) as it manages
the window automatically.
@type channel: subclass of L{SSHChannel}
@type dataType: C{int}
@type data: C{str}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA, struct.pack('>2L',
self.channelsToRemoteChannel[channel],dataType) \
+ common.NS(data))
def sendEOF(self, channel):
"""
Send an EOF (End of File) for a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
log.msg('sending eof')
self.transport.sendPacket(MSG_CHANNEL_EOF, struct.pack('>L',
self.channelsToRemoteChannel[channel]))
def sendClose(self, channel):
"""
Close a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
log.msg('sending close %i' % channel.id)
self.transport.sendPacket(MSG_CHANNEL_CLOSE, struct.pack('>L',
self.channelsToRemoteChannel[channel]))
channel.localClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
# methods to override
def getChannel(self, channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
channelType is the type of channel being requested,
windowSize is the initial size of the remote window,
maxPacket is the largest packet we should send,
data is any other packet data (often nothing).
We return a subclass of L{SSHChannel}.
By default, this dispatches to a method 'channel_channelType' with any
non-alphanumerics in the channelType replace with _'s. If it cannot
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
The method is called with arguments of windowSize, maxPacket, data.
@type channelType: C{str}
@type windowSize: C{int}
@type maxPacket: C{int}
@type data: C{str}
@rtype: subclass of L{SSHChannel}/C{tuple}
"""
log.msg('got channel %s request' % channelType)
if hasattr(self.transport, "avatar"): # this is a server!
chan = self.transport.avatar.lookupChannel(channelType,
windowSize,
maxPacket,
data)
else:
channelType = channelType.translate(TRANSLATE_TABLE)
f = getattr(self, 'channel_%s' % channelType, None)
if f is not None:
chan = f(windowSize, maxPacket, data)
else:
chan = None
if chan is None:
raise error.ConchError('unknown channel',
OPEN_UNKNOWN_CHANNEL_TYPE)
else:
chan.conn = self
return chan
def gotGlobalRequest(self, requestType, data):
"""
We got a global request. pretty much, this is just used by the client
to request that we forward a port from the server to the client.
Returns either:
- 1: request accepted
- 1, <data>: request accepted with request specific data
- 0: request denied
By default, this dispatches to a method 'global_requestType' with
-'s in requestType replaced with _'s. The found method is passed data.
If this method cannot be found, this method returns 0. Otherwise, it
returns the return value of that method.
@type requestType: C{str}
@type data: C{str}
@rtype: C{int}/C{tuple}
"""
log.msg('got global %s request' % requestType)
if hasattr(self.transport, 'avatar'): # this is a server!
return self.transport.avatar.gotGlobalRequest(requestType, data)
requestType = requestType.replace('-','_')
f = getattr(self, 'global_%s' % requestType, None)
if not f:
return 0
return f(data)
def channelClosed(self, channel):
"""
Called when a channel is closed.
It clears the local state related to the channel, and calls
channel.closed().
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
If you don't, things will break mysteriously.
@type channel: L{SSHChannel}
"""
if channel in self.channelsToRemoteChannel: # actually open
channel.localClosed = channel.remoteClosed = True
del self.localToRemoteChannel[channel.id]
del self.channels[channel.id]
del self.channelsToRemoteChannel[channel]
for d in self.deferreds.setdefault(channel.id, []):
d.errback(error.ConchError("Channel closed."))
del self.deferreds[channel.id][:]
log.callWithLogger(channel, channel.closed)
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4
EXTENDED_DATA_STDERR = 1
messages = {}
for name, value in locals().copy().items():
if name[:4] == 'MSG_':
messages[value] = name # doesn't handle doubles
import string
alphanums = string.letters + string.digits
TRANSLATE_TABLE = ''.join([chr(i) in alphanums and chr(i) or '_'
for i in range(256)])
SSHConnection.protocolMessages = messages

View File

@ -0,0 +1,120 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A Factory for SSH servers, along with an OpenSSHFactory to use the same
data sources as OpenSSH.
Maintainer: Paul Swartz
"""
from twisted.internet import protocol
from twisted.python import log
from twisted.conch import error
import transport, userauth, connection
import random
class SSHFactory(protocol.Factory):
"""
A Factory for SSH servers.
"""
protocol = transport.SSHServerTransport
services = {
'ssh-userauth':userauth.SSHUserAuthServer,
'ssh-connection':connection.SSHConnection
}
def startFactory(self):
"""
Check for public and private keys.
"""
if not hasattr(self,'publicKeys'):
self.publicKeys = self.getPublicKeys()
if not hasattr(self,'privateKeys'):
self.privateKeys = self.getPrivateKeys()
if not self.publicKeys or not self.privateKeys:
raise error.ConchError('no host keys, failing')
if not hasattr(self,'primes'):
self.primes = self.getPrimes()
def buildProtocol(self, addr):
"""
Create an instance of the server side of the SSH protocol.
@type addr: L{twisted.internet.interfaces.IAddress} provider
@param addr: The address at which the server will listen.
@rtype: L{twisted.conch.ssh.SSHServerTransport}
@return: The built transport.
"""
t = protocol.Factory.buildProtocol(self, addr)
t.supportedPublicKeys = self.privateKeys.keys()
if not self.primes:
log.msg('disabling diffie-hellman-group-exchange because we '
'cannot find moduli file')
ske = t.supportedKeyExchanges[:]
ske.remove('diffie-hellman-group-exchange-sha1')
t.supportedKeyExchanges = ske
return t
def getPublicKeys(self):
"""
Called when the factory is started to get the public portions of the
servers host keys. Returns a dictionary mapping SSH key types to
public key strings.
@rtype: C{dict}
"""
raise NotImplementedError('getPublicKeys unimplemented')
def getPrivateKeys(self):
"""
Called when the factory is started to get the private portions of the
servers host keys. Returns a dictionary mapping SSH key types to
C{Crypto.PublicKey.pubkey.pubkey} objects.
@rtype: C{dict}
"""
raise NotImplementedError('getPrivateKeys unimplemented')
def getPrimes(self):
"""
Called when the factory is started to get Diffie-Hellman generators and
primes to use. Returns a dictionary mapping number of bits to lists
of tuple of (generator, prime).
@rtype: C{dict}
"""
def getDHPrime(self, bits):
"""
Return a tuple of (g, p) for a Diffe-Hellman process, with p being as
close to bits bits as possible.
@type bits: C{int}
@rtype: C{tuple}
"""
primesKeys = self.primes.keys()
primesKeys.sort(lambda x, y: cmp(abs(x - bits), abs(y - bits)))
realBits = primesKeys[0]
return random.choice(self.primes[realBits])
def getService(self, transport, service):
"""
Return a class to use as a service for the given transport.
@type transport: L{transport.SSHServerTransport}
@type service: C{str}
@rtype: subclass of L{service.SSHService}
"""
if service == 'ssh-userauth' or hasattr(transport, 'avatar'):
return self.services[service]

View File

@ -0,0 +1,933 @@
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import errno
import struct
from zope.interface import implementer
from twisted.conch.interfaces import ISFTPServer, ISFTPFile
from twisted.conch.ssh.common import NS, getNS
from twisted.internet import defer, protocol
from twisted.python import failure, log
class FileTransferBase(protocol.Protocol):
versions = (3, )
packetTypes = {}
def __init__(self):
self.buf = ''
self.otherVersion = None # this gets set
def sendPacket(self, kind, data):
self.transport.write(struct.pack('!LB', len(data)+1, kind) + data)
def dataReceived(self, data):
self.buf += data
while len(self.buf) > 5:
length, kind = struct.unpack('!LB', self.buf[:5])
if len(self.buf) < 4 + length:
return
data, self.buf = self.buf[5:4+length], self.buf[4+length:]
packetType = self.packetTypes.get(kind, None)
if not packetType:
log.msg('no packet type for', kind)
continue
f = getattr(self, 'packet_%s' % packetType, None)
if not f:
log.msg('not implemented: %s' % packetType)
log.msg(repr(data[4:]))
reqId, = struct.unpack('!L', data[:4])
self._sendStatus(reqId, FX_OP_UNSUPPORTED,
"don't understand %s" % packetType)
#XXX not implemented
continue
try:
f(data)
except Exception:
log.err()
continue
def _parseAttributes(self, data):
flags ,= struct.unpack('!L', data[:4])
attrs = {}
data = data[4:]
if flags & FILEXFER_ATTR_SIZE == FILEXFER_ATTR_SIZE:
size ,= struct.unpack('!Q', data[:8])
attrs['size'] = size
data = data[8:]
if flags & FILEXFER_ATTR_OWNERGROUP == FILEXFER_ATTR_OWNERGROUP:
uid, gid = struct.unpack('!2L', data[:8])
attrs['uid'] = uid
attrs['gid'] = gid
data = data[8:]
if flags & FILEXFER_ATTR_PERMISSIONS == FILEXFER_ATTR_PERMISSIONS:
perms ,= struct.unpack('!L', data[:4])
attrs['permissions'] = perms
data = data[4:]
if flags & FILEXFER_ATTR_ACMODTIME == FILEXFER_ATTR_ACMODTIME:
atime, mtime = struct.unpack('!2L', data[:8])
attrs['atime'] = atime
attrs['mtime'] = mtime
data = data[8:]
if flags & FILEXFER_ATTR_EXTENDED == FILEXFER_ATTR_EXTENDED:
extended_count ,= struct.unpack('!L', data[:4])
data = data[4:]
for i in xrange(extended_count):
extended_type, data = getNS(data)
extended_data, data = getNS(data)
attrs['ext_%s' % extended_type] = extended_data
return attrs, data
def _packAttributes(self, attrs):
flags = 0
data = ''
if 'size' in attrs:
data += struct.pack('!Q', attrs['size'])
flags |= FILEXFER_ATTR_SIZE
if 'uid' in attrs and 'gid' in attrs:
data += struct.pack('!2L', attrs['uid'], attrs['gid'])
flags |= FILEXFER_ATTR_OWNERGROUP
if 'permissions' in attrs:
data += struct.pack('!L', attrs['permissions'])
flags |= FILEXFER_ATTR_PERMISSIONS
if 'atime' in attrs and 'mtime' in attrs:
data += struct.pack('!2L', attrs['atime'], attrs['mtime'])
flags |= FILEXFER_ATTR_ACMODTIME
extended = []
for k in attrs:
if k.startswith('ext_'):
ext_type = NS(k[4:])
ext_data = NS(attrs[k])
extended.append(ext_type+ext_data)
if extended:
data += struct.pack('!L', len(extended))
data += ''.join(extended)
flags |= FILEXFER_ATTR_EXTENDED
return struct.pack('!L', flags) + data
class FileTransferServer(FileTransferBase):
def __init__(self, data=None, avatar=None):
FileTransferBase.__init__(self)
self.client = ISFTPServer(avatar) # yay interfaces
self.openFiles = {}
self.openDirs = {}
def packet_INIT(self, data):
version ,= struct.unpack('!L', data[:4])
self.version = min(list(self.versions) + [version])
data = data[4:]
ext = {}
while data:
ext_name, data = getNS(data)
ext_data, data = getNS(data)
ext[ext_name] = ext_data
our_ext = self.client.gotVersion(version, ext)
our_ext_data = ""
for (k,v) in our_ext.items():
our_ext_data += NS(k) + NS(v)
self.sendPacket(FXP_VERSION, struct.pack('!L', self.version) + \
our_ext_data)
def packet_OPEN(self, data):
requestId = data[:4]
data = data[4:]
filename, data = getNS(data)
flags ,= struct.unpack('!L', data[:4])
data = data[4:]
attrs, data = self._parseAttributes(data)
assert data == '', 'still have data in OPEN: %s' % repr(data)
d = defer.maybeDeferred(self.client.openFile, filename, flags, attrs)
d.addCallback(self._cbOpenFile, requestId)
d.addErrback(self._ebStatus, requestId, "open failed")
def _cbOpenFile(self, fileObj, requestId):
fileId = str(hash(fileObj))
if fileId in self.openFiles:
raise KeyError, 'id already open'
self.openFiles[fileId] = fileObj
self.sendPacket(FXP_HANDLE, requestId + NS(fileId))
def packet_CLOSE(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
assert data == '', 'still have data in CLOSE: %s' % repr(data)
if handle in self.openFiles:
fileObj = self.openFiles[handle]
d = defer.maybeDeferred(fileObj.close)
d.addCallback(self._cbClose, handle, requestId)
d.addErrback(self._ebStatus, requestId, "close failed")
elif handle in self.openDirs:
dirObj = self.openDirs[handle][0]
d = defer.maybeDeferred(dirObj.close)
d.addCallback(self._cbClose, handle, requestId, 1)
d.addErrback(self._ebStatus, requestId, "close failed")
else:
self._ebClose(failure.Failure(KeyError()), requestId)
def _cbClose(self, result, handle, requestId, isDir = 0):
if isDir:
del self.openDirs[handle]
else:
del self.openFiles[handle]
self._sendStatus(requestId, FX_OK, 'file closed')
def packet_READ(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
(offset, length), data = struct.unpack('!QL', data[:12]), data[12:]
assert data == '', 'still have data in READ: %s' % repr(data)
if handle not in self.openFiles:
self._ebRead(failure.Failure(KeyError()), requestId)
else:
fileObj = self.openFiles[handle]
d = defer.maybeDeferred(fileObj.readChunk, offset, length)
d.addCallback(self._cbRead, requestId)
d.addErrback(self._ebStatus, requestId, "read failed")
def _cbRead(self, result, requestId):
if result == '': # python's read will return this for EOF
raise EOFError()
self.sendPacket(FXP_DATA, requestId + NS(result))
def packet_WRITE(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
offset, = struct.unpack('!Q', data[:8])
data = data[8:]
writeData, data = getNS(data)
assert data == '', 'still have data in WRITE: %s' % repr(data)
if handle not in self.openFiles:
self._ebWrite(failure.Failure(KeyError()), requestId)
else:
fileObj = self.openFiles[handle]
d = defer.maybeDeferred(fileObj.writeChunk, offset, writeData)
d.addCallback(self._cbStatus, requestId, "write succeeded")
d.addErrback(self._ebStatus, requestId, "write failed")
def packet_REMOVE(self, data):
requestId = data[:4]
data = data[4:]
filename, data = getNS(data)
assert data == '', 'still have data in REMOVE: %s' % repr(data)
d = defer.maybeDeferred(self.client.removeFile, filename)
d.addCallback(self._cbStatus, requestId, "remove succeeded")
d.addErrback(self._ebStatus, requestId, "remove failed")
def packet_RENAME(self, data):
requestId = data[:4]
data = data[4:]
oldPath, data = getNS(data)
newPath, data = getNS(data)
assert data == '', 'still have data in RENAME: %s' % repr(data)
d = defer.maybeDeferred(self.client.renameFile, oldPath, newPath)
d.addCallback(self._cbStatus, requestId, "rename succeeded")
d.addErrback(self._ebStatus, requestId, "rename failed")
def packet_MKDIR(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
attrs, data = self._parseAttributes(data)
assert data == '', 'still have data in MKDIR: %s' % repr(data)
d = defer.maybeDeferred(self.client.makeDirectory, path, attrs)
d.addCallback(self._cbStatus, requestId, "mkdir succeeded")
d.addErrback(self._ebStatus, requestId, "mkdir failed")
def packet_RMDIR(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
assert data == '', 'still have data in RMDIR: %s' % repr(data)
d = defer.maybeDeferred(self.client.removeDirectory, path)
d.addCallback(self._cbStatus, requestId, "rmdir succeeded")
d.addErrback(self._ebStatus, requestId, "rmdir failed")
def packet_OPENDIR(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
assert data == '', 'still have data in OPENDIR: %s' % repr(data)
d = defer.maybeDeferred(self.client.openDirectory, path)
d.addCallback(self._cbOpenDirectory, requestId)
d.addErrback(self._ebStatus, requestId, "opendir failed")
def _cbOpenDirectory(self, dirObj, requestId):
handle = str(hash(dirObj))
if handle in self.openDirs:
raise KeyError, "already opened this directory"
self.openDirs[handle] = [dirObj, iter(dirObj)]
self.sendPacket(FXP_HANDLE, requestId + NS(handle))
def packet_READDIR(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
assert data == '', 'still have data in READDIR: %s' % repr(data)
if handle not in self.openDirs:
self._ebStatus(failure.Failure(KeyError()), requestId)
else:
dirObj, dirIter = self.openDirs[handle]
d = defer.maybeDeferred(self._scanDirectory, dirIter, [])
d.addCallback(self._cbSendDirectory, requestId)
d.addErrback(self._ebStatus, requestId, "scan directory failed")
def _scanDirectory(self, dirIter, f):
while len(f) < 250:
try:
info = dirIter.next()
except StopIteration:
if not f:
raise EOFError
return f
if isinstance(info, defer.Deferred):
info.addCallback(self._cbScanDirectory, dirIter, f)
return
else:
f.append(info)
return f
def _cbScanDirectory(self, result, dirIter, f):
f.append(result)
return self._scanDirectory(dirIter, f)
def _cbSendDirectory(self, result, requestId):
data = ''
for (filename, longname, attrs) in result:
data += NS(filename)
data += NS(longname)
data += self._packAttributes(attrs)
self.sendPacket(FXP_NAME, requestId +
struct.pack('!L', len(result))+data)
def packet_STAT(self, data, followLinks = 1):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
assert data == '', 'still have data in STAT/LSTAT: %s' % repr(data)
d = defer.maybeDeferred(self.client.getAttrs, path, followLinks)
d.addCallback(self._cbStat, requestId)
d.addErrback(self._ebStatus, requestId, 'stat/lstat failed')
def packet_LSTAT(self, data):
self.packet_STAT(data, 0)
def packet_FSTAT(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
assert data == '', 'still have data in FSTAT: %s' % repr(data)
if handle not in self.openFiles:
self._ebStatus(failure.Failure(KeyError('%s not in self.openFiles'
% handle)), requestId)
else:
fileObj = self.openFiles[handle]
d = defer.maybeDeferred(fileObj.getAttrs)
d.addCallback(self._cbStat, requestId)
d.addErrback(self._ebStatus, requestId, 'fstat failed')
def _cbStat(self, result, requestId):
data = requestId + self._packAttributes(result)
self.sendPacket(FXP_ATTRS, data)
def packet_SETSTAT(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
attrs, data = self._parseAttributes(data)
if data != '':
log.msg('WARN: still have data in SETSTAT: %s' % repr(data))
d = defer.maybeDeferred(self.client.setAttrs, path, attrs)
d.addCallback(self._cbStatus, requestId, 'setstat succeeded')
d.addErrback(self._ebStatus, requestId, 'setstat failed')
def packet_FSETSTAT(self, data):
requestId = data[:4]
data = data[4:]
handle, data = getNS(data)
attrs, data = self._parseAttributes(data)
assert data == '', 'still have data in FSETSTAT: %s' % repr(data)
if handle not in self.openFiles:
self._ebStatus(failure.Failure(KeyError()), requestId)
else:
fileObj = self.openFiles[handle]
d = defer.maybeDeferred(fileObj.setAttrs, attrs)
d.addCallback(self._cbStatus, requestId, 'fsetstat succeeded')
d.addErrback(self._ebStatus, requestId, 'fsetstat failed')
def packet_READLINK(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
assert data == '', 'still have data in READLINK: %s' % repr(data)
d = defer.maybeDeferred(self.client.readLink, path)
d.addCallback(self._cbReadLink, requestId)
d.addErrback(self._ebStatus, requestId, 'readlink failed')
def _cbReadLink(self, result, requestId):
self._cbSendDirectory([(result, '', {})], requestId)
def packet_SYMLINK(self, data):
requestId = data[:4]
data = data[4:]
linkPath, data = getNS(data)
targetPath, data = getNS(data)
d = defer.maybeDeferred(self.client.makeLink, linkPath, targetPath)
d.addCallback(self._cbStatus, requestId, 'symlink succeeded')
d.addErrback(self._ebStatus, requestId, 'symlink failed')
def packet_REALPATH(self, data):
requestId = data[:4]
data = data[4:]
path, data = getNS(data)
assert data == '', 'still have data in REALPATH: %s' % repr(data)
d = defer.maybeDeferred(self.client.realPath, path)
d.addCallback(self._cbReadLink, requestId) # same return format
d.addErrback(self._ebStatus, requestId, 'realpath failed')
def packet_EXTENDED(self, data):
requestId = data[:4]
data = data[4:]
extName, extData = getNS(data)
d = defer.maybeDeferred(self.client.extendedRequest, extName, extData)
d.addCallback(self._cbExtended, requestId)
d.addErrback(self._ebStatus, requestId, 'extended %s failed' % extName)
def _cbExtended(self, data, requestId):
self.sendPacket(FXP_EXTENDED_REPLY, requestId + data)
def _cbStatus(self, result, requestId, msg = "request succeeded"):
self._sendStatus(requestId, FX_OK, msg)
def _ebStatus(self, reason, requestId, msg = "request failed"):
code = FX_FAILURE
message = msg
if reason.type in (IOError, OSError):
if reason.value.errno == errno.ENOENT: # no such file
code = FX_NO_SUCH_FILE
message = reason.value.strerror
elif reason.value.errno == errno.EACCES: # permission denied
code = FX_PERMISSION_DENIED
message = reason.value.strerror
elif reason.value.errno == errno.EEXIST:
code = FX_FILE_ALREADY_EXISTS
else:
log.err(reason)
elif reason.type == EOFError: # EOF
code = FX_EOF
if reason.value.args:
message = reason.value.args[0]
elif reason.type == NotImplementedError:
code = FX_OP_UNSUPPORTED
if reason.value.args:
message = reason.value.args[0]
elif reason.type == SFTPError:
code = reason.value.code
message = reason.value.message
else:
log.err(reason)
self._sendStatus(requestId, code, message)
def _sendStatus(self, requestId, code, message, lang = ''):
"""
Helper method to send a FXP_STATUS message.
"""
data = requestId + struct.pack('!L', code)
data += NS(message)
data += NS(lang)
self.sendPacket(FXP_STATUS, data)
def connectionLost(self, reason):
"""
Clean all opened files and directories.
"""
for fileObj in self.openFiles.values():
fileObj.close()
self.openFiles = {}
for (dirObj, dirIter) in self.openDirs.values():
dirObj.close()
self.openDirs = {}
class FileTransferClient(FileTransferBase):
def __init__(self, extData = {}):
"""
@param extData: a dict of extended_name : extended_data items
to be sent to the server.
"""
FileTransferBase.__init__(self)
self.extData = {}
self.counter = 0
self.openRequests = {} # id -> Deferred
self.wasAFile = {} # Deferred -> 1 TERRIBLE HACK
def connectionMade(self):
data = struct.pack('!L', max(self.versions))
for k,v in self.extData.itervalues():
data += NS(k) + NS(v)
self.sendPacket(FXP_INIT, data)
def _sendRequest(self, msg, data):
data = struct.pack('!L', self.counter) + data
d = defer.Deferred()
self.openRequests[self.counter] = d
self.counter += 1
self.sendPacket(msg, data)
return d
def _parseRequest(self, data):
(id,) = struct.unpack('!L', data[:4])
d = self.openRequests[id]
del self.openRequests[id]
return d, data[4:]
def openFile(self, filename, flags, attrs):
"""
Open a file.
This method returns a L{Deferred} that is called back with an object
that provides the L{ISFTPFile} interface.
@param filename: a string representing the file to open.
@param flags: a integer of the flags to open the file with, ORed together.
The flags and their values are listed at the bottom of this file.
@param attrs: a list of attributes to open the file with. It is a
dictionary, consisting of 0 or more keys. The possible keys are::
size: the size of the file in bytes
uid: the user ID of the file as an integer
gid: the group ID of the file as an integer
permissions: the permissions of the file with as an integer.
the bit representation of this field is defined by POSIX.
atime: the access time of the file as seconds since the epoch.
mtime: the modification time of the file as seconds since the epoch.
ext_*: extended attributes. The server is not required to
understand this, but it may.
NOTE: there is no way to indicate text or binary files. it is up
to the SFTP client to deal with this.
"""
data = NS(filename) + struct.pack('!L', flags) + self._packAttributes(attrs)
d = self._sendRequest(FXP_OPEN, data)
self.wasAFile[d] = (1, filename) # HACK
return d
def removeFile(self, filename):
"""
Remove the given file.
This method returns a Deferred that is called back when it succeeds.
@param filename: the name of the file as a string.
"""
return self._sendRequest(FXP_REMOVE, NS(filename))
def renameFile(self, oldpath, newpath):
"""
Rename the given file.
This method returns a Deferred that is called back when it succeeds.
@param oldpath: the current location of the file.
@param newpath: the new file name.
"""
return self._sendRequest(FXP_RENAME, NS(oldpath)+NS(newpath))
def makeDirectory(self, path, attrs):
"""
Make a directory.
This method returns a Deferred that is called back when it is
created.
@param path: the name of the directory to create as a string.
@param attrs: a dictionary of attributes to create the directory
with. Its meaning is the same as the attrs in the openFile method.
"""
return self._sendRequest(FXP_MKDIR, NS(path)+self._packAttributes(attrs))
def removeDirectory(self, path):
"""
Remove a directory (non-recursively)
It is an error to remove a directory that has files or directories in
it.
This method returns a Deferred that is called back when it is removed.
@param path: the directory to remove.
"""
return self._sendRequest(FXP_RMDIR, NS(path))
def openDirectory(self, path):
"""
Open a directory for scanning.
This method returns a Deferred that is called back with an iterable
object that has a close() method.
The close() method is called when the client is finished reading
from the directory. At this point, the iterable will no longer
be used.
The iterable returns triples of the form (filename, longname, attrs)
or a Deferred that returns the same. The sequence must support
__getitem__, but otherwise may be any 'sequence-like' object.
filename is the name of the file relative to the directory.
logname is an expanded format of the filename. The recommended format
is:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
The first line is sample output, the second is the length of the field.
The fields are: permissions, link count, user owner, group owner,
size in bytes, modification time.
attrs is a dictionary in the format of the attrs argument to openFile.
@param path: the directory to open.
"""
d = self._sendRequest(FXP_OPENDIR, NS(path))
self.wasAFile[d] = (0, path)
return d
def getAttrs(self, path, followLinks=0):
"""
Return the attributes for the given path.
This method returns a dictionary in the same format as the attrs
argument to openFile or a Deferred that is called back with same.
@param path: the path to return attributes for as a string.
@param followLinks: a boolean. if it is True, follow symbolic links
and return attributes for the real path at the base. if it is False,
return attributes for the specified path.
"""
if followLinks: m = FXP_STAT
else: m = FXP_LSTAT
return self._sendRequest(m, NS(path))
def setAttrs(self, path, attrs):
"""
Set the attributes for the path.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param path: the path to set attributes for as a string.
@param attrs: a dictionary in the same format as the attrs argument to
openFile.
"""
data = NS(path) + self._packAttributes(attrs)
return self._sendRequest(FXP_SETSTAT, data)
def readLink(self, path):
"""
Find the root of a set of symbolic links.
This method returns the target of the link, or a Deferred that
returns the same.
@param path: the path of the symlink to read.
"""
d = self._sendRequest(FXP_READLINK, NS(path))
return d.addCallback(self._cbRealPath)
def makeLink(self, linkPath, targetPath):
"""
Create a symbolic link.
This method returns when the link is made, or a Deferred that
returns the same.
@param linkPath: the pathname of the symlink as a string
@param targetPath: the path of the target of the link as a string.
"""
return self._sendRequest(FXP_SYMLINK, NS(linkPath)+NS(targetPath))
def realPath(self, path):
"""
Convert any path to an absolute path.
This method returns the absolute path as a string, or a Deferred
that returns the same.
@param path: the path to convert as a string.
"""
d = self._sendRequest(FXP_REALPATH, NS(path))
return d.addCallback(self._cbRealPath)
def _cbRealPath(self, result):
name, longname, attrs = result[0]
return name
def extendedRequest(self, request, data):
"""
Make an extended request of the server.
The method returns a Deferred that is called back with
the result of the extended request.
@param request: the name of the extended request to make.
@param data: any other data that goes along with the request.
"""
return self._sendRequest(FXP_EXTENDED, NS(request) + data)
def packet_VERSION(self, data):
version, = struct.unpack('!L', data[:4])
data = data[4:]
d = {}
while data:
k, data = getNS(data)
v, data = getNS(data)
d[k]=v
self.version = version
self.gotServerVersion(version, d)
def packet_STATUS(self, data):
d, data = self._parseRequest(data)
code, = struct.unpack('!L', data[:4])
data = data[4:]
if len(data) >= 4:
msg, data = getNS(data)
if len(data) >= 4:
lang, data = getNS(data)
else:
lang = ''
else:
msg = ''
lang = ''
if code == FX_OK:
d.callback((msg, lang))
elif code == FX_EOF:
d.errback(EOFError(msg))
elif code == FX_OP_UNSUPPORTED:
d.errback(NotImplementedError(msg))
else:
d.errback(SFTPError(code, msg, lang))
def packet_HANDLE(self, data):
d, data = self._parseRequest(data)
isFile, name = self.wasAFile.pop(d)
if isFile:
cb = ClientFile(self, getNS(data)[0])
else:
cb = ClientDirectory(self, getNS(data)[0])
cb.name = name
d.callback(cb)
def packet_DATA(self, data):
d, data = self._parseRequest(data)
d.callback(getNS(data)[0])
def packet_NAME(self, data):
d, data = self._parseRequest(data)
count, = struct.unpack('!L', data[:4])
data = data[4:]
files = []
for i in range(count):
filename, data = getNS(data)
longname, data = getNS(data)
attrs, data = self._parseAttributes(data)
files.append((filename, longname, attrs))
d.callback(files)
def packet_ATTRS(self, data):
d, data = self._parseRequest(data)
d.callback(self._parseAttributes(data)[0])
def packet_EXTENDED_REPLY(self, data):
d, data = self._parseRequest(data)
d.callback(data)
def gotServerVersion(self, serverVersion, extData):
"""
Called when the client sends their version info.
@param otherVersion: an integer representing the version of the SFTP
protocol they are claiming.
@param extData: a dictionary of extended_name : extended_data items.
These items are sent by the client to indicate additional features.
"""
@implementer(ISFTPFile)
class ClientFile:
def __init__(self, parent, handle):
self.parent = parent
self.handle = NS(handle)
def close(self):
return self.parent._sendRequest(FXP_CLOSE, self.handle)
def readChunk(self, offset, length):
data = self.handle + struct.pack("!QL", offset, length)
return self.parent._sendRequest(FXP_READ, data)
def writeChunk(self, offset, chunk):
data = self.handle + struct.pack("!Q", offset) + NS(chunk)
return self.parent._sendRequest(FXP_WRITE, data)
def getAttrs(self):
return self.parent._sendRequest(FXP_FSTAT, self.handle)
def setAttrs(self, attrs):
data = self.handle + self.parent._packAttributes(attrs)
return self.parent._sendRequest(FXP_FSTAT, data)
class ClientDirectory:
def __init__(self, parent, handle):
self.parent = parent
self.handle = NS(handle)
self.filesCache = []
def read(self):
d = self.parent._sendRequest(FXP_READDIR, self.handle)
return d
def close(self):
return self.parent._sendRequest(FXP_CLOSE, self.handle)
def __iter__(self):
return self
def next(self):
if self.filesCache:
return self.filesCache.pop(0)
d = self.read()
d.addCallback(self._cbReadDir)
d.addErrback(self._ebReadDir)
return d
def _cbReadDir(self, names):
self.filesCache = names[1:]
return names[0]
def _ebReadDir(self, reason):
reason.trap(EOFError)
def _():
raise StopIteration
self.next = _
return reason
class SFTPError(Exception):
def __init__(self, errorCode, errorMessage, lang = ''):
Exception.__init__(self)
self.code = errorCode
self._message = errorMessage
self.lang = lang
def message(self):
"""
A string received over the network that explains the error to a human.
"""
# Python 2.6 deprecates assigning to the 'message' attribute of an
# exception. We define this read-only property here in order to
# prevent the warning about deprecation while maintaining backwards
# compatibility with object clients that rely on the 'message'
# attribute being set correctly. See bug #3897.
return self._message
message = property(message)
def __str__(self):
return 'SFTPError %s: %s' % (self.code, self.message)
FXP_INIT = 1
FXP_VERSION = 2
FXP_OPEN = 3
FXP_CLOSE = 4
FXP_READ = 5
FXP_WRITE = 6
FXP_LSTAT = 7
FXP_FSTAT = 8
FXP_SETSTAT = 9
FXP_FSETSTAT = 10
FXP_OPENDIR = 11
FXP_READDIR = 12
FXP_REMOVE = 13
FXP_MKDIR = 14
FXP_RMDIR = 15
FXP_REALPATH = 16
FXP_STAT = 17
FXP_RENAME = 18
FXP_READLINK = 19
FXP_SYMLINK = 20
FXP_STATUS = 101
FXP_HANDLE = 102
FXP_DATA = 103
FXP_NAME = 104
FXP_ATTRS = 105
FXP_EXTENDED = 200
FXP_EXTENDED_REPLY = 201
FILEXFER_ATTR_SIZE = 0x00000001
FILEXFER_ATTR_UIDGID = 0x00000002
FILEXFER_ATTR_OWNERGROUP = FILEXFER_ATTR_UIDGID
FILEXFER_ATTR_PERMISSIONS = 0x00000004
FILEXFER_ATTR_ACMODTIME = 0x00000008
FILEXFER_ATTR_EXTENDED = 0x80000000L
FILEXFER_TYPE_REGULAR = 1
FILEXFER_TYPE_DIRECTORY = 2
FILEXFER_TYPE_SYMLINK = 3
FILEXFER_TYPE_SPECIAL = 4
FILEXFER_TYPE_UNKNOWN = 5
FXF_READ = 0x00000001
FXF_WRITE = 0x00000002
FXF_APPEND = 0x00000004
FXF_CREAT = 0x00000008
FXF_TRUNC = 0x00000010
FXF_EXCL = 0x00000020
FXF_TEXT = 0x00000040
FX_OK = 0
FX_EOF = 1
FX_NO_SUCH_FILE = 2
FX_PERMISSION_DENIED = 3
FX_FAILURE = 4
FX_BAD_MESSAGE = 5
FX_NO_CONNECTION = 6
FX_CONNECTION_LOST = 7
FX_OP_UNSUPPORTED = 8
FX_FILE_ALREADY_EXISTS = 11
# http://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/ defines more
# useful error codes, but so far OpenSSH doesn't implement them. We use them
# internally for clarity, but for now define them all as FX_FAILURE to be
# compatible with existing software.
FX_NOT_A_DIRECTORY = FX_FAILURE
FX_FILE_IS_A_DIRECTORY = FX_FAILURE
# initialize FileTransferBase.packetTypes:
g = globals()
for name in g.keys():
if name.startswith('FXP_'):
value = g[name]
FileTransferBase.packetTypes[value] = name[4:]
del g, name, value

View File

@ -0,0 +1,236 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the TCP forwarding, which allows
clients and servers to forward arbitrary TCP data across the connection.
Maintainer: Paul Swartz
"""
import struct
from twisted.internet import protocol, reactor
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
from twisted.python import log
import common, channel
class SSHListenForwardingFactory(protocol.Factory):
def __init__(self, connection, hostport, klass):
self.conn = connection
self.hostport = hostport # tuple
self.klass = klass
def buildProtocol(self, addr):
channel = self.klass(conn = self.conn)
client = SSHForwardingClient(channel)
channel.client = client
addrTuple = (addr.host, addr.port)
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
self.conn.openChannel(channel, channelOpenData)
return client
class SSHListenForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
log.msg('opened forwarding channel %s' % self.id)
if len(self.client.buf)>1:
b = self.client.buf[1:]
self.write(b)
self.client.buf = ''
def openFailed(self, reason):
self.closed()
def dataReceived(self, data):
self.client.transport.write(data)
def eofReceived(self):
self.client.transport.loseConnection()
def closed(self):
if hasattr(self, 'client'):
log.msg('closing local forwarding channel %s' % self.id)
self.client.transport.loseConnection()
del self.client
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
name = 'direct-tcpip'
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
name = 'forwarded-tcpip'
class SSHConnectForwardingChannel(channel.SSHChannel):
"""
Channel used for handling server side forwarding request.
It acts as a client for the remote forwarding destination.
@ivar hostport: C{(host, port)} requested by client as forwarding
destination.
@type hostport: C{tupple} or a C{sequence}
@ivar client: Protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
@ivar clientBuf: Data received while forwarding channel is not yet
connected.
@type clientBuf: C{bytes}
@var _reactor: Reactor used for TCP connections.
@type _reactor: A reactor.
@ivar _channelOpenDeferred: Deferred used in testing to check the
result of C{channelOpen}.
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
"""
_reactor = reactor
def __init__(self, hostport, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.hostport = hostport
self.client = None
self.clientBuf = ''
def channelOpen(self, specificData):
"""
See: L{channel.SSHChannel}
"""
log.msg("connecting to %s:%i" % self.hostport)
ep = HostnameEndpoint(
self._reactor, self.hostport[0], self.hostport[1])
d = connectProtocol(ep, SSHForwardingClient(self))
d.addCallbacks(self._setClient, self._close)
self._channelOpenDeferred = d
def _setClient(self, client):
"""
Called when the connection was established to the forwarding
destination.
@param client: Client protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
"""
self.client = client
log.msg("connected to %s:%i" % self.hostport)
if self.clientBuf:
self.client.transport.write(self.clientBuf)
self.clientBuf = None
if self.client.buf[1:]:
self.write(self.client.buf[1:])
self.client.buf = ''
def _close(self, reason):
"""
Called when failed to connect to the forwarding destination.
@param reason: Reason why connection failed.
@type reason: L{twisted.python.failure.Failure}
"""
log.msg("failed to connect: %s" % reason)
self.loseConnection()
def dataReceived(self, data):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self.client.transport.write(data)
else:
self.clientBuf += data
def closed(self):
"""
See: L{channel.SSHChannel}
"""
if self.client:
log.msg('closed remote forwarding channel %s' % self.id)
if self.client.channel:
self.loseConnection()
self.client.transport.loseConnection()
del self.client
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
remoteHP, origHP = unpackOpen_direct_tcpip(data)
return SSHConnectForwardingChannel(remoteHP,
remoteWindow=remoteWindow,
remoteMaxPacket=remoteMaxPacket,
avatar=avatar)
class SSHForwardingClient(protocol.Protocol):
def __init__(self, channel):
self.channel = channel
self.buf = '\000'
def dataReceived(self, data):
if self.buf:
self.buf += data
else:
self.channel.write(data)
def connectionLost(self, reason):
if self.channel:
self.channel.loseConnection()
self.channel = None
def packOpen_direct_tcpip((connHost, connPort), (origHost, origPort)):
"""Pack the data suitable for sending in a CHANNEL_OPEN packet.
"""
conn = common.NS(connHost) + struct.pack('>L', connPort)
orig = common.NS(origHost) + struct.pack('>L', origPort)
return conn + orig
packOpen_forwarded_tcpip = packOpen_direct_tcpip
def unpackOpen_direct_tcpip(data):
"""Unpack the data to a usable format.
"""
connHost, rest = common.getNS(data)
connPort = int(struct.unpack('>L', rest[:4])[0])
origHost, rest = common.getNS(rest[4:])
origPort = int(struct.unpack('>L', rest[:4])[0])
return (connHost, connPort), (origHost, origPort)
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
def packGlobal_tcpip_forward((host, port)):
return common.NS(host) + struct.pack('>L', port)
def unpackGlobal_tcpip_forward(data):
host, rest = common.getNS(data)
port = int(struct.unpack('>L', rest[:4])[0])
return host, port
"""This is how the data -> eof -> close stuff /should/ work.
debug3: channel 1: waiting for connection
debug1: channel 1: connected
debug1: channel 1: read<=0 rfd 7 len 0
debug1: channel 1: read failed
debug1: channel 1: close_read
debug1: channel 1: input open -> drain
debug1: channel 1: ibuf empty
debug1: channel 1: send eof
debug1: channel 1: input drain -> closed
debug1: channel 1: rcvd eof
debug1: channel 1: output open -> drain
debug1: channel 1: obuf empty
debug1: channel 1: close_write
debug1: channel 1: output drain -> closed
debug1: channel 1: rcvd close
debug3: channel 1: will not send data after close
debug1: channel 1: send close
debug1: channel 1: is dead
"""

View File

@ -0,0 +1,857 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Handling of RSA and DSA keys.
Maintainer: U{Paul Swartz}
"""
# base library imports
import base64
import itertools
from hashlib import md5, sha1
# external library imports
from Crypto.Cipher import DES3, AES
from Crypto.PublicKey import RSA, DSA
from Crypto import Util
from pyasn1.error import PyAsn1Error
from pyasn1.type import univ
from pyasn1.codec.ber import decoder as berDecoder
from pyasn1.codec.ber import encoder as berEncoder
# twisted
from twisted.python import randbytes
# sibling imports
from twisted.conch.ssh import common, sexpy
class BadKeyError(Exception):
"""
Raised when a key isn't what we expected from it.
XXX: we really need to check for bad keys
"""
class EncryptedKeyError(Exception):
"""
Raised when an encrypted key is presented to fromString/fromFile without
a password.
"""
class Key(object):
"""
An object representing a key. A key can be either a public or
private key. A public key can verify a signature; a private key can
create or verify a signature. To generate a string that can be stored
on disk, use the toString method. If you have a private key, but want
the string representation of the public key, use Key.public().toString().
@ivar keyObject: The C{Crypto.PublicKey.pubkey.pubkey} object that
operations are performed with.
"""
def fromFile(Class, filename, type=None, passphrase=None):
"""
Return a Key object corresponding to the data in filename. type
and passphrase function as they do in fromString.
"""
return Class.fromString(file(filename, 'rb').read(), type, passphrase)
fromFile = classmethod(fromFile)
def fromString(Class, data, type=None, passphrase=None):
"""
Return a Key object corresponding to the string data.
type is optionally the type of string, matching a _fromString_*
method. Otherwise, the _guessStringType() classmethod will be used
to guess a type. If the key is encrypted, passphrase is used as
the decryption key.
@type data: C{str}
@type type: C{None}/C{str}
@type passphrase: C{None}/C{str}
@rtype: C{Key}
"""
if type is None:
type = Class._guessStringType(data)
if type is None:
raise BadKeyError('cannot guess the type of %r' % data)
method = getattr(Class, '_fromString_%s' % type.upper(), None)
if method is None:
raise BadKeyError('no _fromString method for %s' % type)
if method.func_code.co_argcount == 2: # no passphrase
if passphrase:
raise BadKeyError('key not encrypted')
return method(data)
else:
return method(data, passphrase)
fromString = classmethod(fromString)
def _fromString_BLOB(Class, blob):
"""
Return a public key object corresponding to this public key blob.
The format of a RSA public key blob is::
string 'ssh-rsa'
integer e
integer n
The format of a DSA public key blob is::
string 'ssh-dss'
integer p
integer q
integer g
integer y
@type blob: C{str}
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the key type (the first string) is unknown.
"""
keyType, rest = common.getNS(blob)
if keyType == 'ssh-rsa':
e, n, rest = common.getMP(rest, 2)
return Class(RSA.construct((n, e)))
elif keyType == 'ssh-dss':
p, q, g, y, rest = common.getMP(rest, 4)
return Class(DSA.construct((y, g, p, q)))
else:
raise BadKeyError('unknown blob type: %s' % keyType)
_fromString_BLOB = classmethod(_fromString_BLOB)
def _fromString_PRIVATE_BLOB(Class, blob):
"""
Return a private key object corresponding to this private key blob.
The blob formats are as follows:
RSA keys::
string 'ssh-rsa'
integer n
integer e
integer d
integer u
integer p
integer q
DSA keys::
string 'ssh-dss'
integer p
integer q
integer g
integer y
integer x
@type blob: C{str}
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the key type (the first string) is unknown.
"""
keyType, rest = common.getNS(blob)
if keyType == 'ssh-rsa':
n, e, d, u, p, q, rest = common.getMP(rest, 6)
rsakey = Class(RSA.construct((n, e, d, p, q, u)))
return rsakey
elif keyType == 'ssh-dss':
p, q, g, y, x, rest = common.getMP(rest, 5)
dsakey = Class(DSA.construct((y, g, p, q, x)))
return dsakey
else:
raise BadKeyError('unknown blob type: %s' % keyType)
_fromString_PRIVATE_BLOB = classmethod(_fromString_PRIVATE_BLOB)
def _fromString_PUBLIC_OPENSSH(Class, data):
"""
Return a public key object corresponding to this OpenSSH public key
string. The format of an OpenSSH public key string is::
<key type> <base64-encoded public key blob>
@type data: C{str}
@return: A {Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the blob type is unknown.
"""
blob = base64.decodestring(data.split()[1])
return Class._fromString_BLOB(blob)
_fromString_PUBLIC_OPENSSH = classmethod(_fromString_PUBLIC_OPENSSH)
def _fromString_PRIVATE_OPENSSH(Class, data, passphrase):
"""
Return a private key object corresponding to this OpenSSH private key
string. If the key is encrypted, passphrase MUST be provided.
Providing a passphrase for an unencrypted key is an error.
The format of an OpenSSH private key string is::
-----BEGIN <key type> PRIVATE KEY-----
[Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,<initialization value>]
<base64-encoded ASN.1 structure>
------END <key type> PRIVATE KEY------
The ASN.1 structure of a RSA key is::
(0, n, e, d, p, q)
The ASN.1 structure of a DSA key is::
(0, p, q, g, y, x)
@type data: C{str}
@type passphrase: C{str}
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if
* a passphrase is provided for an unencrypted key
* the ASN.1 encoding is incorrect
@raises EncryptedKeyError: if
* a passphrase is not provided for an encrypted key
"""
lines = data.strip().split('\n')
kind = lines[0][11:14]
if lines[1].startswith('Proc-Type: 4,ENCRYPTED'): # encrypted key
if not passphrase:
raise EncryptedKeyError('Passphrase must be provided '
'for an encrypted key')
# Determine cipher and initialization vector
try:
_, cipher_iv_info = lines[2].split(' ', 1)
cipher, ivdata = cipher_iv_info.rstrip().split(',', 1)
except ValueError:
raise BadKeyError('invalid DEK-info %r' % lines[2])
if cipher == 'AES-128-CBC':
CipherClass = AES
keySize = 16
if len(ivdata) != 32:
raise BadKeyError('AES encrypted key with a bad IV')
elif cipher == 'DES-EDE3-CBC':
CipherClass = DES3
keySize = 24
if len(ivdata) != 16:
raise BadKeyError('DES encrypted key with a bad IV')
else:
raise BadKeyError('unknown encryption type %r' % cipher)
# extract keyData for decoding
iv = ''.join([chr(int(ivdata[i:i + 2], 16))
for i in range(0, len(ivdata), 2)])
ba = md5(passphrase + iv[:8]).digest()
bb = md5(ba + passphrase + iv[:8]).digest()
decKey = (ba + bb)[:keySize]
b64Data = base64.decodestring(''.join(lines[3:-1]))
keyData = CipherClass.new(decKey,
CipherClass.MODE_CBC,
iv).decrypt(b64Data)
removeLen = ord(keyData[-1])
keyData = keyData[:-removeLen]
else:
b64Data = ''.join(lines[1:-1])
keyData = base64.decodestring(b64Data)
try:
decodedKey = berDecoder.decode(keyData)[0]
except PyAsn1Error, e:
raise BadKeyError('Failed to decode key (Bad Passphrase?): %s' % e)
if kind == 'RSA':
if len(decodedKey) == 2: # alternate RSA key
decodedKey = decodedKey[0]
if len(decodedKey) < 6:
raise BadKeyError('RSA key failed to decode properly')
n, e, d, p, q = [long(value) for value in decodedKey[1:6]]
if p > q: # make p smaller than q
p, q = q, p
return Class(RSA.construct((n, e, d, p, q)))
elif kind == 'DSA':
p, q, g, y, x = [long(value) for value in decodedKey[1: 6]]
if len(decodedKey) < 6:
raise BadKeyError('DSA key failed to decode properly')
return Class(DSA.construct((y, g, p, q, x)))
_fromString_PRIVATE_OPENSSH = classmethod(_fromString_PRIVATE_OPENSSH)
def _fromString_PUBLIC_LSH(Class, data):
"""
Return a public key corresponding to this LSH public key string.
The LSH public key string format is::
<s-expression: ('public-key', (<key type>, (<name, <value>)+))>
The names for a RSA (key type 'rsa-pkcs1-sha1') key are: n, e.
The names for a DSA (key type 'dsa') key are: y, g, p, q.
@type data: C{str}
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the key type is unknown
"""
sexp = sexpy.parse(base64.decodestring(data[1:-1]))
assert sexp[0] == 'public-key'
kd = {}
for name, data in sexp[1][1:]:
kd[name] = common.getMP(common.NS(data))[0]
if sexp[1][0] == 'dsa':
return Class(DSA.construct((kd['y'], kd['g'], kd['p'], kd['q'])))
elif sexp[1][0] == 'rsa-pkcs1-sha1':
return Class(RSA.construct((kd['n'], kd['e'])))
else:
raise BadKeyError('unknown lsh key type %s' % sexp[1][0])
_fromString_PUBLIC_LSH = classmethod(_fromString_PUBLIC_LSH)
def _fromString_PRIVATE_LSH(Class, data):
"""
Return a private key corresponding to this LSH private key string.
The LSH private key string format is::
<s-expression: ('private-key', (<key type>, (<name>, <value>)+))>
The names for a RSA (key type 'rsa-pkcs1-sha1') key are: n, e, d, p, q.
The names for a DSA (key type 'dsa') key are: y, g, p, q, x.
@type data: C{str}
@return: a {Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the key type is unknown
"""
sexp = sexpy.parse(data)
assert sexp[0] == 'private-key'
kd = {}
for name, data in sexp[1][1:]:
kd[name] = common.getMP(common.NS(data))[0]
if sexp[1][0] == 'dsa':
assert len(kd) == 5, len(kd)
return Class(DSA.construct((kd['y'], kd['g'], kd['p'],
kd['q'], kd['x'])))
elif sexp[1][0] == 'rsa-pkcs1':
assert len(kd) == 8, len(kd)
if kd['p'] > kd['q']: # make p smaller than q
kd['p'], kd['q'] = kd['q'], kd['p']
return Class(RSA.construct((kd['n'], kd['e'], kd['d'],
kd['p'], kd['q'])))
else:
raise BadKeyError('unknown lsh key type %s' % sexp[1][0])
_fromString_PRIVATE_LSH = classmethod(_fromString_PRIVATE_LSH)
def _fromString_AGENTV3(Class, data):
"""
Return a private key object corresponsing to the Secure Shell Key
Agent v3 format.
The SSH Key Agent v3 format for a RSA key is::
string 'ssh-rsa'
integer e
integer d
integer n
integer u
integer p
integer q
The SSH Key Agent v3 format for a DSA key is::
string 'ssh-dss'
integer p
integer q
integer g
integer y
integer x
@type data: C{str}
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
@raises BadKeyError: if the key type (the first string) is unknown
"""
keyType, data = common.getNS(data)
if keyType == 'ssh-dss':
p, data = common.getMP(data)
q, data = common.getMP(data)
g, data = common.getMP(data)
y, data = common.getMP(data)
x, data = common.getMP(data)
return Class(DSA.construct((y, g, p, q, x)))
elif keyType == 'ssh-rsa':
e, data = common.getMP(data)
d, data = common.getMP(data)
n, data = common.getMP(data)
u, data = common.getMP(data)
p, data = common.getMP(data)
q, data = common.getMP(data)
return Class(RSA.construct((n, e, d, p, q, u)))
else:
raise BadKeyError("unknown key type %s" % keyType)
_fromString_AGENTV3 = classmethod(_fromString_AGENTV3)
def _guessStringType(Class, data):
"""
Guess the type of key in data. The types map to _fromString_*
methods.
"""
if data.startswith('ssh-'):
return 'public_openssh'
elif data.startswith('-----BEGIN'):
return 'private_openssh'
elif data.startswith('{'):
return 'public_lsh'
elif data.startswith('('):
return 'private_lsh'
elif data.startswith('\x00\x00\x00\x07ssh-'):
ignored, rest = common.getNS(data)
count = 0
while rest:
count += 1
ignored, rest = common.getMP(rest)
if count > 4:
return 'agentv3'
else:
return 'blob'
_guessStringType = classmethod(_guessStringType)
def __init__(self, keyObject):
"""
Initialize a PublicKey with a C{Crypto.PublicKey.pubkey.pubkey}
object.
@type keyObject: C{Crypto.PublicKey.pubkey.pubkey}
"""
self.keyObject = keyObject
def __eq__(self, other):
"""
Return True if other represents an object with the same key.
"""
if type(self) == type(other):
return self.type() == other.type() and self.data() == other.data()
else:
return NotImplemented
def __ne__(self, other):
"""
Return True if other represents anything other than this key.
"""
result = self.__eq__(other)
if result == NotImplemented:
return result
return not result
def __repr__(self):
"""
Return a pretty representation of this object.
"""
lines = [
'<%s %s (%s bits)' % (
self.type(),
self.isPublic() and 'Public Key' or 'Private Key',
self.keyObject.size())]
for k, v in sorted(self.data().items()):
lines.append('attr %s:' % k)
by = common.MP(v)[4:]
while by:
m = by[:15]
by = by[15:]
o = ''
for c in m:
o = o + '%02x:' % ord(c)
if len(m) < 15:
o = o[:-1]
lines.append('\t' + o)
lines[-1] = lines[-1] + '>'
return '\n'.join(lines)
def isPublic(self):
"""
Returns True if this Key is a public key.
"""
return not self.keyObject.has_private()
def public(self):
"""
Returns a version of this key containing only the public key data.
If this is a public key, this may or may not be the same object
as self.
"""
return Key(self.keyObject.publickey())
def fingerprint(self):
"""
Get the user presentation of the fingerprint of this L{Key}. As
described by U{RFC 4716 section
4<http://tools.ietf.org/html/rfc4716#section-4>}::
The fingerprint of a public key consists of the output of the MD5
message-digest algorithm [RFC1321]. The input to the algorithm is
the public key data as specified by [RFC4253]. (...) The output
of the (MD5) algorithm is presented to the user as a sequence of 16
octets printed as hexadecimal with lowercase letters and separated
by colons.
@since: 8.2
@return: the user presentation of this L{Key}'s fingerprint, as a
string.
@rtype: L{str}
"""
return ':'.join([x.encode('hex') for x in md5(self.blob()).digest()])
def type(self):
"""
Return the type of the object we wrap. Currently this can only be
'RSA' or 'DSA'.
"""
# the class is Crypto.PublicKey.<type>.<stuff we don't care about>
mod = self.keyObject.__class__.__module__
if mod.startswith('Crypto.PublicKey'):
type = mod.split('.')[2]
else:
raise RuntimeError('unknown type of object: %r' % self.keyObject)
if type in ('RSA', 'DSA'):
return type
else:
raise RuntimeError('unknown type of key: %s' % type)
def sshType(self):
"""
Return the type of the object we wrap as defined in the ssh protocol.
Currently this can only be 'ssh-rsa' or 'ssh-dss'.
"""
return {'RSA': 'ssh-rsa', 'DSA': 'ssh-dss'}[self.type()]
def data(self):
"""
Return the values of the public key as a dictionary.
@rtype: C{dict}
"""
keyData = {}
for name in self.keyObject.keydata:
value = getattr(self.keyObject, name, None)
if value is not None:
keyData[name] = value
return keyData
def blob(self):
"""
Return the public key blob for this key. The blob is the
over-the-wire format for public keys:
RSA keys::
string 'ssh-rsa'
integer e
integer n
DSA keys::
string 'ssh-dss'
integer p
integer q
integer g
integer y
@rtype: C{str}
"""
type = self.type()
data = self.data()
if type == 'RSA':
return (common.NS('ssh-rsa') + common.MP(data['e']) +
common.MP(data['n']))
elif type == 'DSA':
return (common.NS('ssh-dss') + common.MP(data['p']) +
common.MP(data['q']) + common.MP(data['g']) +
common.MP(data['y']))
def privateBlob(self):
"""
Return the private key blob for this key. The blob is the
over-the-wire format for private keys:
RSA keys::
string 'ssh-rsa'
integer n
integer e
integer d
integer u
integer p
integer q
DSA keys::
string 'ssh-dss'
integer p
integer q
integer g
integer y
integer x
"""
type = self.type()
data = self.data()
if type == 'RSA':
return (common.NS('ssh-rsa') + common.MP(data['n']) +
common.MP(data['e']) + common.MP(data['d']) +
common.MP(data['u']) + common.MP(data['p']) +
common.MP(data['q']))
elif type == 'DSA':
return (common.NS('ssh-dss') + common.MP(data['p']) +
common.MP(data['q']) + common.MP(data['g']) +
common.MP(data['y']) + common.MP(data['x']))
def toString(self, type, extra=None):
"""
Create a string representation of this key. If the key is a private
key and you want the represenation of its public key, use
C{key.public().toString()}. type maps to a _toString_* method.
@param type: The type of string to emit. Currently supported values
are C{'OPENSSH'}, C{'LSH'}, and C{'AGENTV3'}.
@type type: L{str}
@param extra: Any extra data supported by the selected format which
is not part of the key itself. For public OpenSSH keys, this is
a comment. For private OpenSSH keys, this is a passphrase to
encrypt with.
@type extra: L{str} or L{NoneType}
@rtype: L{str}
"""
method = getattr(self, '_toString_%s' % type.upper(), None)
if method is None:
raise BadKeyError('unknown type: %s' % type)
if method.func_code.co_argcount == 2:
return method(extra)
else:
return method()
def _toString_OPENSSH(self, extra):
"""
Return a public or private OpenSSH string. See
_fromString_PUBLIC_OPENSSH and _fromString_PRIVATE_OPENSSH for the
string formats. If extra is present, it represents a comment for a
public key, or a passphrase for a private key.
@param extra: Comment for a public key or passphrase for a
private key
@type extra: C{str}
@rtype: C{str}
"""
data = self.data()
if self.isPublic():
b64Data = base64.encodestring(self.blob()).replace('\n', '')
if not extra:
extra = ''
return ('%s %s %s' % (self.sshType(), b64Data, extra)).strip()
else:
lines = ['-----BEGIN %s PRIVATE KEY-----' % self.type()]
if self.type() == 'RSA':
p, q = data['p'], data['q']
objData = (0, data['n'], data['e'], data['d'], q, p,
data['d'] % (q - 1), data['d'] % (p - 1),
data['u'])
else:
objData = (0, data['p'], data['q'], data['g'], data['y'],
data['x'])
asn1Sequence = univ.Sequence()
for index, value in itertools.izip(itertools.count(), objData):
asn1Sequence.setComponentByPosition(index, univ.Integer(value))
asn1Data = berEncoder.encode(asn1Sequence)
if extra:
iv = randbytes.secureRandom(8)
hexiv = ''.join(['%02X' % ord(x) for x in iv])
lines.append('Proc-Type: 4,ENCRYPTED')
lines.append('DEK-Info: DES-EDE3-CBC,%s\n' % hexiv)
ba = md5(extra + iv).digest()
bb = md5(ba + extra + iv).digest()
encKey = (ba + bb)[:24]
padLen = 8 - (len(asn1Data) % 8)
asn1Data += (chr(padLen) * padLen)
asn1Data = DES3.new(encKey, DES3.MODE_CBC,
iv).encrypt(asn1Data)
b64Data = base64.encodestring(asn1Data).replace('\n', '')
lines += [b64Data[i:i + 64] for i in range(0, len(b64Data), 64)]
lines.append('-----END %s PRIVATE KEY-----' % self.type())
return '\n'.join(lines)
def _toString_LSH(self):
"""
Return a public or private LSH key. See _fromString_PUBLIC_LSH and
_fromString_PRIVATE_LSH for the key formats.
@rtype: C{str}
"""
data = self.data()
if self.isPublic():
if self.type() == 'RSA':
keyData = sexpy.pack([['public-key',
['rsa-pkcs1-sha1',
['n', common.MP(data['n'])[4:]],
['e', common.MP(data['e'])[4:]]]]])
elif self.type() == 'DSA':
keyData = sexpy.pack([['public-key',
['dsa',
['p', common.MP(data['p'])[4:]],
['q', common.MP(data['q'])[4:]],
['g', common.MP(data['g'])[4:]],
['y', common.MP(data['y'])[4:]]]]])
return '{' + base64.encodestring(keyData).replace('\n', '') + '}'
else:
if self.type() == 'RSA':
p, q = data['p'], data['q']
return sexpy.pack([['private-key',
['rsa-pkcs1',
['n', common.MP(data['n'])[4:]],
['e', common.MP(data['e'])[4:]],
['d', common.MP(data['d'])[4:]],
['p', common.MP(q)[4:]],
['q', common.MP(p)[4:]],
['a', common.MP(data['d'] % (q - 1))[4:]],
['b', common.MP(data['d'] % (p - 1))[4:]],
['c', common.MP(data['u'])[4:]]]]])
elif self.type() == 'DSA':
return sexpy.pack([['private-key',
['dsa',
['p', common.MP(data['p'])[4:]],
['q', common.MP(data['q'])[4:]],
['g', common.MP(data['g'])[4:]],
['y', common.MP(data['y'])[4:]],
['x', common.MP(data['x'])[4:]]]]])
def _toString_AGENTV3(self):
"""
Return a private Secure Shell Agent v3 key. See
_fromString_AGENTV3 for the key format.
@rtype: C{str}
"""
data = self.data()
if not self.isPublic():
if self.type() == 'RSA':
values = (data['e'], data['d'], data['n'], data['u'],
data['p'], data['q'])
elif self.type() == 'DSA':
values = (data['p'], data['q'], data['g'], data['y'],
data['x'])
return common.NS(self.sshType()) + ''.join(map(common.MP, values))
def sign(self, data):
"""
Returns a signature with this Key.
@type data: C{str}
@rtype: C{str}
"""
if self.type() == 'RSA':
digest = pkcs1Digest(data, self.keyObject.size() / 8)
signature = self.keyObject.sign(digest, '')[0]
ret = common.NS(Util.number.long_to_bytes(signature))
elif self.type() == 'DSA':
digest = sha1(data).digest()
randomBytes = randbytes.secureRandom(19)
sig = self.keyObject.sign(digest, randomBytes)
# SSH insists that the DSS signature blob be two 160-bit integers
# concatenated together. The sig[0], [1] numbers from obj.sign
# are just numbers, and could be any length from 0 to 160 bits.
# Make sure they are padded out to 160 bits (20 bytes each)
ret = common.NS(Util.number.long_to_bytes(sig[0], 20) +
Util.number.long_to_bytes(sig[1], 20))
return common.NS(self.sshType()) + ret
def verify(self, signature, data):
"""
Returns true if the signature for data is valid for this Key.
@type signature: C{str}
@type data: C{str}
@rtype: C{bool}
"""
if len(signature) == 40:
# DSA key with no padding
signatureType, signature = 'ssh-dss', common.NS(signature)
else:
signatureType, signature = common.getNS(signature)
if signatureType != self.sshType():
return False
if self.type() == 'RSA':
numbers = common.getMP(signature)
digest = pkcs1Digest(data, self.keyObject.size() / 8)
elif self.type() == 'DSA':
signature = common.getNS(signature)[0]
numbers = [Util.number.bytes_to_long(n) for n in signature[:20],
signature[20:]]
digest = sha1(data).digest()
return self.keyObject.verify(digest, numbers)
def objectType(obj):
"""
Return the SSH key type corresponding to a
C{Crypto.PublicKey.pubkey.pubkey} object.
@type obj: C{Crypto.PublicKey.pubkey.pubkey}
@rtype: C{str}
"""
keyDataMapping = {
('n', 'e', 'd', 'p', 'q'): 'ssh-rsa',
('n', 'e', 'd', 'p', 'q', 'u'): 'ssh-rsa',
('y', 'g', 'p', 'q', 'x'): 'ssh-dss'
}
try:
return keyDataMapping[tuple(obj.keydata)]
except (KeyError, AttributeError):
raise BadKeyError("invalid key object", obj)
def pkcs1Pad(data, messageLength):
"""
Pad out data to messageLength according to the PKCS#1 standard.
@type data: C{str}
@type messageLength: C{int}
"""
lenPad = messageLength - 2 - len(data)
return '\x01' + ('\xff' * lenPad) + '\x00' + data
def pkcs1Digest(data, messageLength):
"""
Create a message digest using the SHA1 hash algorithm according to the
PKCS#1 standard.
@type data: C{str}
@type messageLength: C{str}
"""
digest = sha1(data).digest()
return pkcs1Pad(ID_SHA1 + digest, messageLength)
def lenSig(obj):
"""
Return the length of the signature in bytes for a key object.
@type obj: C{Crypto.PublicKey.pubkey.pubkey}
@rtype: C{long}
"""
return obj.size() / 8
ID_SHA1 = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'

View File

@ -0,0 +1,48 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH services. Currently implemented services
are ssh-userauth and ssh-connection.
Maintainer: Paul Swartz
"""
from twisted.python import log
class SSHService(log.Logger):
name = None # this is the ssh name for the service
protocolMessages = {} # these map #'s -> protocol names
transport = None # gets set later
def serviceStarted(self):
"""
called when the service is active on the transport.
"""
def serviceStopped(self):
"""
called when the service is stopped, either by the connection ending
or by another service being started
"""
def logPrefix(self):
return "SSHService %s on %s" % (self.name,
self.transport.transport.logPrefix())
def packetReceived(self, messageNum, packet):
"""
called when we receive a packet on the transport
"""
#print self.protocolMessages
if messageNum in self.protocolMessages:
messageType = self.protocolMessages[messageNum]
f = getattr(self,'ssh_%s' % messageType[4:],
None)
if f is not None:
return f(packet)
log.msg("couldn't handle %r" % messageNum)
log.msg(repr(packet))
self.transport.sendUnimplemented()

View File

@ -0,0 +1,349 @@
# -*- test-case-name: twisted.conch.test.test_session -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of SSHSession, which (by default)
allows access to a shell and a python interpreter over SSH.
Maintainer: Paul Swartz
"""
import struct
import signal
import sys
import os
from zope.interface import implementer
from twisted.internet import interfaces, protocol
from twisted.python import log
from twisted.conch.interfaces import ISession
from twisted.conch.ssh import common, channel
class SSHSession(channel.SSHChannel):
name = 'session'
def __init__(self, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.buf = ''
self.client = None
self.session = None
def request_subsystem(self, data):
subsystem, ignored= common.getNS(data)
log.msg('asking for subsystem "%s"' % subsystem)
client = self.avatar.lookupSubsystem(subsystem, data)
if client:
pp = SSHSessionProcessProtocol(self)
proto = wrapProcessProtocol(pp)
client.makeConnection(proto)
pp.makeConnection(wrapProtocol(client))
self.client = pp
return 1
else:
log.msg('failed to get subsystem')
return 0
def request_shell(self, data):
log.msg('getting shell')
if not self.session:
self.session = ISession(self.avatar)
try:
pp = SSHSessionProcessProtocol(self)
self.session.openShell(pp)
except:
log.deferr()
return 0
else:
self.client = pp
return 1
def request_exec(self, data):
if not self.session:
self.session = ISession(self.avatar)
f,data = common.getNS(data)
log.msg('executing command "%s"' % f)
try:
pp = SSHSessionProcessProtocol(self)
self.session.execCommand(pp, f)
except:
log.deferr()
return 0
else:
self.client = pp
return 1
def request_pty_req(self, data):
if not self.session:
self.session = ISession(self.avatar)
term, windowSize, modes = parseRequest_pty_req(data)
log.msg('pty request: %s %s' % (term, windowSize))
try:
self.session.getPty(term, windowSize, modes)
except:
log.err()
return 0
else:
return 1
def request_window_change(self, data):
if not self.session:
self.session = ISession(self.avatar)
winSize = parseRequest_window_change(data)
try:
self.session.windowChanged(winSize)
except:
log.msg('error changing window size')
log.err()
return 0
else:
return 1
def dataReceived(self, data):
if not self.client:
#self.conn.sendClose(self)
self.buf += data
return
self.client.transport.write(data)
def extReceived(self, dataType, data):
if dataType == connection.EXTENDED_DATA_STDERR:
if self.client and hasattr(self.client.transport, 'writeErr'):
self.client.transport.writeErr(data)
else:
log.msg('weird extended data: %s'%dataType)
def eofReceived(self):
if self.session:
self.session.eofReceived()
elif self.client:
self.conn.sendClose(self)
def closed(self):
if self.session:
self.session.closed()
elif self.client:
self.client.transport.loseConnection()
#def closeReceived(self):
# self.loseConnection() # don't know what to do with this
def loseConnection(self):
if self.client:
self.client.transport.loseConnection()
channel.SSHChannel.loseConnection(self)
class _ProtocolWrapper(protocol.ProcessProtocol):
"""
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
"""
def __init__(self, proto):
self.proto = proto
def connectionMade(self): self.proto.connectionMade()
def outReceived(self, data): self.proto.dataReceived(data)
def processEnded(self, reason): self.proto.connectionLost(reason)
class _DummyTransport:
def __init__(self, proto):
self.proto = proto
def dataReceived(self, data):
self.proto.transport.write(data)
def write(self, data):
self.proto.dataReceived(data)
def writeSequence(self, seq):
self.write(''.join(seq))
def loseConnection(self):
self.proto.connectionLost(protocol.connectionDone)
def wrapProcessProtocol(inst):
if isinstance(inst, protocol.Protocol):
return _ProtocolWrapper(inst)
else:
return inst
def wrapProtocol(proto):
return _DummyTransport(proto)
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
# to accept. See RFC 4254
SUPPORTED_SIGNALS = ["ABRT", "ALRM", "FPE", "HUP", "ILL", "INT", "KILL",
"PIPE", "QUIT", "SEGV", "TERM", "USR1", "USR2"]
@implementer(interfaces.ITransport)
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
"""I am both an L{IProcessProtocol} and an L{ITransport}.
I am a transport to the remote endpoint and a process protocol to the
local subsystem.
"""
# once initialized, a dictionary mapping signal values to strings
# that follow RFC 4254.
_signalValuesToNames = None
def __init__(self, session):
self.session = session
self.lostOutOrErrFlag = False
def connectionMade(self):
if self.session.buf:
self.transport.write(self.session.buf)
self.session.buf = None
def outReceived(self, data):
self.session.write(data)
def errReceived(self, err):
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
def outConnectionLost(self):
"""
EOF should only be sent when both STDOUT and STDERR have been closed.
"""
if self.lostOutOrErrFlag:
self.session.conn.sendEOF(self.session)
else:
self.lostOutOrErrFlag = True
def errConnectionLost(self):
"""
See outConnectionLost().
"""
self.outConnectionLost()
def connectionLost(self, reason = None):
self.session.loseConnection()
def _getSignalName(self, signum):
"""
Get a signal name given a signal number.
"""
if self._signalValuesToNames is None:
self._signalValuesToNames = {}
# make sure that the POSIX ones are the defaults
for signame in SUPPORTED_SIGNALS:
signame = 'SIG' + signame
sigvalue = getattr(signal, signame, None)
if sigvalue is not None:
self._signalValuesToNames[sigvalue] = signame
for k, v in signal.__dict__.items():
# Check for platform specific signals, ignoring Python specific
# SIG_DFL and SIG_IGN
if k.startswith('SIG') and not k.startswith('SIG_'):
if v not in self._signalValuesToNames:
self._signalValuesToNames[v] = k + '@' + sys.platform
return self._signalValuesToNames[signum]
def processEnded(self, reason=None):
"""
When we are told the process ended, try to notify the other side about
how the process ended using the exit-signal or exit-status requests.
Also, close the channel.
"""
if reason is not None:
err = reason.value
if err.signal is not None:
signame = self._getSignalName(err.signal)
if (getattr(os, 'WCOREDUMP', None) is not None and
os.WCOREDUMP(err.status)):
log.msg('exitSignal: %s (core dumped)' % (signame,))
coreDumped = 1
else:
log.msg('exitSignal: %s' % (signame,))
coreDumped = 0
self.session.conn.sendRequest(self.session, 'exit-signal',
common.NS(signame[3:]) + chr(coreDumped) +
common.NS('') + common.NS(''))
elif err.exitCode is not None:
log.msg('exitCode: %r' % (err.exitCode,))
self.session.conn.sendRequest(self.session, 'exit-status',
struct.pack('>L', err.exitCode))
self.session.loseConnection()
def getHost(self):
"""
Return the host from my session's transport.
"""
return self.session.conn.transport.getHost()
def getPeer(self):
"""
Return the peer from my session's transport.
"""
return self.session.conn.transport.getPeer()
def write(self, data):
self.session.write(data)
def writeSequence(self, seq):
self.session.write(''.join(seq))
def loseConnection(self):
self.session.loseConnection()
class SSHSessionClient(protocol.Protocol):
def dataReceived(self, data):
if self.transport:
self.transport.write(data)
# methods factored out to make live easier on server writers
def parseRequest_pty_req(data):
"""Parse the data from a pty-req request into usable data.
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
"""
term, rest = common.getNS(data)
cols, rows, xpixel, ypixel = struct.unpack('>4L', rest[: 16])
modes, ignored= common.getNS(rest[16:])
winSize = (rows, cols, xpixel, ypixel)
modes = [(ord(modes[i]), struct.unpack('>L', modes[i+1: i+5])[0]) for i in range(0, len(modes)-1, 5)]
return term, winSize, modes
def packRequest_pty_req(term, (rows, cols, xpixel, ypixel), modes):
"""Pack a pty-req request so that it is suitable for sending.
NOTE: modes must be packed before being sent here.
"""
termPacked = common.NS(term)
winSizePacked = struct.pack('>4L', cols, rows, xpixel, ypixel)
modesPacked = common.NS(modes) # depend on the client packing modes
return termPacked + winSizePacked + modesPacked
def parseRequest_window_change(data):
"""Parse the data from a window-change request into usuable data.
@returns: a tuple of (rows, cols, xpixel, ypixel)
"""
cols, rows, xpixel, ypixel = struct.unpack('>4L', data)
return rows, cols, xpixel, ypixel
def packRequest_window_change((rows, cols, xpixel, ypixel)):
"""Pack a window-change request so that it is suitable for sending.
"""
return struct.pack('>4L', cols, rows, xpixel, ypixel)
import connection

View File

@ -0,0 +1,42 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
def parse(s):
s = s.strip()
expr = []
while s:
if s[0] == '(':
newSexp = []
if expr:
expr[-1].append(newSexp)
expr.append(newSexp)
s = s[1:]
continue
if s[0] == ')':
aList = expr.pop()
s=s[1:]
if not expr:
assert not s
return aList
continue
i = 0
while s[i].isdigit(): i+=1
assert i
length = int(s[:i])
data = s[i+1:i+1+length]
expr[-1].append(data)
s=s[i+1+length:]
assert 0, "this should not happen"
def pack(sexp):
s = ""
for o in sexp:
if type(o) in (type(()), type([])):
s+='('
s+=pack(o)
s+=')'
else:
s+='%i:%s' % (len(o), o)
return s

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,838 @@
# -*- test-case-name: twisted.conch.test.test_userauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the ssh-userauth service.
Currently implemented authentication types are public-key and password.
Maintainer: Paul Swartz
"""
import struct
from twisted.conch import error, interfaces
from twisted.conch.ssh import keys, transport, service
from twisted.conch.ssh.common import NS, getNS
from twisted.cred import credentials
from twisted.cred.error import UnauthorizedLogin
from twisted.internet import defer, reactor
from twisted.python import failure, log
class SSHUserAuthServer(service.SSHService):
"""
A service implementing the server side of the 'ssh-userauth' service. It
is used to authenticate the user on the other side as being able to access
this server.
@ivar name: the name of this service: 'ssh-userauth'
@type name: C{str}
@ivar authenticatedWith: a list of authentication methods that have
already been used.
@type authenticatedWith: C{list}
@ivar loginTimeout: the number of seconds we wait before disconnecting
the user for taking too long to authenticate
@type loginTimeout: C{int}
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
allow before disconnecting.
@type attemptsBeforeDisconnect: C{int}
@ivar loginAttempts: the number of login attempts that have been made
@type loginAttempts: C{int}
@ivar passwordDelay: the number of seconds to delay when the user gives
an incorrect password
@type passwordDelay: C{int}
@ivar interfaceToMethod: a C{dict} mapping credential interfaces to
authentication methods. The server checks to see which of the
cred interfaces have checkers and tells the client that those methods
are valid for authentication.
@type interfaceToMethod: C{dict}
@ivar supportedAuthentications: A list of the supported authentication
methods.
@type supportedAuthentications: C{list} of C{str}
@ivar user: the last username the client tried to authenticate with
@type user: C{str}
@ivar method: the current authentication method
@type method: C{str}
@ivar nextService: the service the user wants started after authentication
has been completed.
@type nextService: C{str}
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
authentication
@type portal: L{twisted.cred.portal.Portal}
@ivar clock: an object with a callLater method. Stubbed out for testing.
"""
name = 'ssh-userauth'
loginTimeout = 10 * 60 * 60
# 10 minutes before we disconnect them
attemptsBeforeDisconnect = 20
# 20 login attempts before a disconnect
passwordDelay = 1 # number of seconds to delay on a failed password
clock = reactor
interfaceToMethod = {
credentials.ISSHPrivateKey : 'publickey',
credentials.IUsernamePassword : 'password',
credentials.IPluggableAuthenticationModules : 'keyboard-interactive',
}
def serviceStarted(self):
"""
Called when the userauth service is started. Set up instance
variables, check if we should allow password/keyboard-interactive
authentication (only allow if the outgoing connection is encrypted) and
set up a login timeout.
"""
self.authenticatedWith = []
self.loginAttempts = 0
self.user = None
self.nextService = None
self._pamDeferred = None
self.portal = self.transport.factory.portal
self.supportedAuthentications = []
for i in self.portal.listCredentialsInterfaces():
if i in self.interfaceToMethod:
self.supportedAuthentications.append(self.interfaceToMethod[i])
if not self.transport.isEncrypted('in'):
# don't let us transport password in plaintext
if 'password' in self.supportedAuthentications:
self.supportedAuthentications.remove('password')
if 'keyboard-interactive' in self.supportedAuthentications:
self.supportedAuthentications.remove('keyboard-interactive')
self._cancelLoginTimeout = self.clock.callLater(
self.loginTimeout,
self.timeoutAuthentication)
def serviceStopped(self):
"""
Called when the userauth service is stopped. Cancel the login timeout
if it's still going.
"""
if self._cancelLoginTimeout:
self._cancelLoginTimeout.cancel()
self._cancelLoginTimeout = None
def timeoutAuthentication(self):
"""
Called when the user has timed out on authentication. Disconnect
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
"""
self._cancelLoginTimeout = None
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
'you took too long')
def tryAuth(self, kind, user, data):
"""
Try to authenticate the user with the given method. Dispatches to a
auth_* method.
@param kind: the authentication method to try.
@type kind: C{str}
@param user: the username the client is authenticating with.
@type user: C{str}
@param data: authentication specific data sent by the client.
@type data: C{str}
@return: A Deferred called back if the method succeeded, or erred back
if it failed.
@rtype: C{defer.Deferred}
"""
log.msg('%s trying auth %s' % (user, kind))
if kind not in self.supportedAuthentications:
return defer.fail(
error.ConchError('unsupported authentication, failing'))
kind = kind.replace('-', '_')
f = getattr(self,'auth_%s'%kind, None)
if f:
ret = f(data)
if not ret:
return defer.fail(
error.ConchError('%s return None instead of a Deferred'
% kind))
else:
return ret
return defer.fail(error.ConchError('bad auth type: %s' % kind))
def ssh_USERAUTH_REQUEST(self, packet):
"""
The client has requested authentication. Payload::
string user
string next service
string method
<authentication specific data>
@type packet: C{str}
"""
user, nextService, method, rest = getNS(packet, 3)
if user != self.user or nextService != self.nextService:
self.authenticatedWith = [] # clear auth state
self.user = user
self.nextService = nextService
self.method = method
d = self.tryAuth(method, user, rest)
if not d:
self._ebBadAuth(
failure.Failure(error.ConchError('auth returned none')))
return
d.addCallback(self._cbFinishedAuth)
d.addErrback(self._ebMaybeBadAuth)
d.addErrback(self._ebBadAuth)
return d
def _cbFinishedAuth(self, (interface, avatar, logout)):
"""
The callback when user has successfully been authenticated. For a
description of the arguments, see L{twisted.cred.portal.Portal.login}.
We start the service requested by the user.
"""
self.transport.avatar = avatar
self.transport.logoutFunction = logout
service = self.transport.factory.getService(self.transport,
self.nextService)
if not service:
raise error.ConchError('could not get next service: %s'
% self.nextService)
log.msg('%s authenticated with %s' % (self.user, self.method))
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, '')
self.transport.setService(service())
def _ebMaybeBadAuth(self, reason):
"""
An intermediate errback. If the reason is
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
with the partial success indicator set.
@type reason: L{twisted.python.failure.Failure}
"""
reason.trap(error.NotEnoughAuthentication)
self.transport.sendPacket(MSG_USERAUTH_FAILURE,
NS(','.join(self.supportedAuthentications)) + '\xff')
def _ebBadAuth(self, reason):
"""
The final errback in the authentication chain. If the reason is
error.IgnoreAuthentication, we simply return; the authentication
method has sent its own response. Otherwise, send a failure message
and (if the method is not 'none') increment the number of login
attempts.
@type reason: L{twisted.python.failure.Failure}
"""
if reason.check(error.IgnoreAuthentication):
return
if self.method != 'none':
log.msg('%s failed auth %s' % (self.user, self.method))
if reason.check(UnauthorizedLogin):
log.msg('unauthorized login: %s' % reason.getErrorMessage())
elif reason.check(error.ConchError):
log.msg('reason: %s' % reason.getErrorMessage())
else:
log.msg(reason.getTraceback())
self.loginAttempts += 1
if self.loginAttempts > self.attemptsBeforeDisconnect:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
'too many bad auths')
return
self.transport.sendPacket(
MSG_USERAUTH_FAILURE,
NS(','.join(self.supportedAuthentications)) + '\x00')
def auth_publickey(self, packet):
"""
Public key authentication. Payload::
byte has signature
string algorithm name
string key blob
[string signature] (if has signature is True)
Create a SSHPublicKey credential and verify it using our portal.
"""
hasSig = ord(packet[0])
algName, blob, rest = getNS(packet[1:], 2)
pubKey = keys.Key.fromString(blob)
signature = hasSig and getNS(rest)[0] or None
if hasSig:
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
NS(self.user) + NS(self.nextService) + NS('publickey') +
chr(hasSig) + NS(pubKey.sshType()) + NS(blob))
c = credentials.SSHPrivateKey(self.user, algName, blob, b,
signature)
return self.portal.login(c, None, interfaces.IConchUser)
else:
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
return self.portal.login(c, None,
interfaces.IConchUser).addErrback(self._ebCheckKey,
packet[1:])
def _ebCheckKey(self, reason, packet):
"""
Called back if the user did not sent a signature. If reason is
error.ValidPublicKey then this key is valid for the user to
authenticate with. Send MSG_USERAUTH_PK_OK.
"""
reason.trap(error.ValidPublicKey)
# if we make it here, it means that the publickey is valid
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
return failure.Failure(error.IgnoreAuthentication())
def auth_password(self, packet):
"""
Password authentication. Payload::
string password
Make a UsernamePassword credential and verify it with our portal.
"""
password = getNS(packet[1:])[0]
c = credentials.UsernamePassword(self.user, password)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebPassword)
def _ebPassword(self, f):
"""
If the password is invalid, wait before sending the failure in order
to delay brute-force password guessing.
"""
d = defer.Deferred()
self.clock.callLater(self.passwordDelay, d.callback, f)
return d
def auth_keyboard_interactive(self, packet):
"""
Keyboard interactive authentication. No payload. We create a
PluggableAuthenticationModules credential and authenticate with our
portal.
"""
if self._pamDeferred is not None:
self.transport.sendDisconnect(
transport.DISCONNECT_PROTOCOL_ERROR,
"only one keyboard interactive attempt at a time")
return defer.fail(error.IgnoreAuthentication())
c = credentials.PluggableAuthenticationModules(self.user,
self._pamConv)
return self.portal.login(c, None, interfaces.IConchUser)
def _pamConv(self, items):
"""
Convert a list of PAM authentication questions into a
MSG_USERAUTH_INFO_REQUEST. Returns a Deferred that will be called
back when the user has responses to the questions.
@param items: a list of 2-tuples (message, kind). We only care about
kinds 1 (password) and 2 (text).
@type items: C{list}
@rtype: L{defer.Deferred}
"""
resp = []
for message, kind in items:
if kind == 1: # password
resp.append((message, 0))
elif kind == 2: # text
resp.append((message, 1))
elif kind in (3, 4):
return defer.fail(error.ConchError(
'cannot handle PAM 3 or 4 messages'))
else:
return defer.fail(error.ConchError(
'bad PAM auth kind %i' % kind))
packet = NS('') + NS('') + NS('')
packet += struct.pack('>L', len(resp))
for prompt, echo in resp:
packet += NS(prompt)
packet += chr(echo)
self.transport.sendPacket(MSG_USERAUTH_INFO_REQUEST, packet)
self._pamDeferred = defer.Deferred()
return self._pamDeferred
def ssh_USERAUTH_INFO_RESPONSE(self, packet):
"""
The user has responded with answers to PAMs authentication questions.
Parse the packet into a PAM response and callback self._pamDeferred.
Payload::
uint32 numer of responses
string response 1
...
string response n
"""
d, self._pamDeferred = self._pamDeferred, None
try:
resp = []
numResps = struct.unpack('>L', packet[:4])[0]
packet = packet[4:]
while len(resp) < numResps:
response, packet = getNS(packet)
resp.append((response, 0))
if packet:
raise error.ConchError("%i bytes of extra data" % len(packet))
except:
d.errback(failure.Failure())
else:
d.callback(resp)
class SSHUserAuthClient(service.SSHService):
"""
A service implementing the client side of 'ssh-userauth'.
This service will try all authentication methods provided by the server,
making callbacks for more information when necessary.
@ivar name: the name of this service: 'ssh-userauth'
@type name: C{str}
@ivar preferredOrder: a list of authentication methods that should be used
first, in order of preference, if supported by the server
@type preferredOrder: C{list}
@ivar user: the name of the user to authenticate as
@type user: C{str}
@ivar instance: the service to start after authentication has finished
@type instance: L{service.SSHService}
@ivar authenticatedWith: a list of strings of authentication methods we've tried
@type authenticatedWith: C{list} of C{str}
@ivar triedPublicKeys: a list of public key objects that we've tried to
authenticate with
@type triedPublicKeys: C{list} of L{Key}
@ivar lastPublicKey: the last public key object we've tried to authenticate
with
@type lastPublicKey: L{Key}
"""
name = 'ssh-userauth'
preferredOrder = ['publickey', 'password', 'keyboard-interactive']
def __init__(self, user, instance):
self.user = user
self.instance = instance
def serviceStarted(self):
self.authenticatedWith = []
self.triedPublicKeys = []
self.lastPublicKey = None
self.askForAuth('none', '')
def askForAuth(self, kind, extraData):
"""
Send a MSG_USERAUTH_REQUEST.
@param kind: the authentication method to try.
@type kind: C{str}
@param extraData: method-specific data to go in the packet
@type extraData: C{str}
"""
self.lastAuth = kind
self.transport.sendPacket(MSG_USERAUTH_REQUEST, NS(self.user) +
NS(self.instance.name) + NS(kind) + extraData)
def tryAuth(self, kind):
"""
Dispatch to an authentication method.
@param kind: the authentication method
@type kind: C{str}
"""
kind = kind.replace('-', '_')
log.msg('trying to auth with %s' % (kind,))
f = getattr(self,'auth_%s' % (kind,), None)
if f:
return f()
def _ebAuth(self, ignored, *args):
"""
Generic callback for a failed authentication attempt. Respond by
asking for the list of accepted methods (the 'none' method)
"""
self.askForAuth('none', '')
def ssh_USERAUTH_SUCCESS(self, packet):
"""
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
authentication, so start the next service.
"""
self.transport.setService(self.instance)
def ssh_USERAUTH_FAILURE(self, packet):
"""
We received a MSG_USERAUTH_FAILURE. Payload::
string methods
byte partial success
If partial success is C{True}, then the previous method succeeded but is
not sufficient for authentication. C{methods} is a comma-separated list
of accepted authentication methods.
We sort the list of methods by their position in C{self.preferredOrder},
removing methods that have already succeeded. We then call
C{self.tryAuth} with the most preferred method.
@param packet: the L{MSG_USERAUTH_FAILURE} payload.
@type packet: C{str}
@return: a L{defer.Deferred} that will be callbacked with C{None} as
soon as all authentication methods have been tried, or C{None} if no
more authentication methods are available.
@rtype: C{defer.Deferred} or C{None}
"""
canContinue, partial = getNS(packet)
partial = ord(partial)
if partial:
self.authenticatedWith.append(self.lastAuth)
def orderByPreference(meth):
"""
Invoked once per authentication method in order to extract a
comparison key which is then used for sorting.
@param meth: the authentication method.
@type meth: C{str}
@return: the comparison key for C{meth}.
@rtype: C{int}
"""
if meth in self.preferredOrder:
return self.preferredOrder.index(meth)
else:
# put the element at the end of the list.
return len(self.preferredOrder)
canContinue = sorted([meth for meth in canContinue.split(',')
if meth not in self.authenticatedWith],
key=orderByPreference)
log.msg('can continue with: %s' % canContinue)
return self._cbUserauthFailure(None, iter(canContinue))
def _cbUserauthFailure(self, result, iterator):
if result:
return
try:
method = iterator.next()
except StopIteration:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
'no more authentication methods available')
else:
d = defer.maybeDeferred(self.tryAuth, method)
d.addCallback(self._cbUserauthFailure, iterator)
return d
def ssh_USERAUTH_PK_OK(self, packet):
"""
This message (number 60) can mean several different messages depending
on the current authentication type. We dispatch to individual methods
in order to handle this request.
"""
func = getattr(self, 'ssh_USERAUTH_PK_OK_%s' %
self.lastAuth.replace('-', '_'), None)
if func is not None:
return func(packet)
else:
self.askForAuth('none', '')
def ssh_USERAUTH_PK_OK_publickey(self, packet):
"""
This is MSG_USERAUTH_PK. Our public key is valid, so we create a
signature and try to authenticate with it.
"""
publicKey = self.lastPublicKey
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
NS(self.user) + NS(self.instance.name) + NS('publickey') +
'\x01' + NS(publicKey.sshType()) + NS(publicKey.blob()))
d = self.signData(publicKey, b)
if not d:
self.askForAuth('none', '')
# this will fail, we'll move on
return
d.addCallback(self._cbSignedData)
d.addErrback(self._ebAuth)
def ssh_USERAUTH_PK_OK_password(self, packet):
"""
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
We ask for an old password and a new password, then send both back to
the server.
"""
prompt, language, rest = getNS(packet, 2)
self._oldPass = self._newPass = None
d = self.getPassword('Old Password: ')
d = d.addCallbacks(self._setOldPass, self._ebAuth)
d.addCallback(lambda ignored: self.getPassword(prompt))
d.addCallbacks(self._setNewPass, self._ebAuth)
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
"""
This is MSG_USERAUTH_INFO_RESPONSE. The server has sent us the
questions it wants us to answer, so we ask the user and sent the
responses.
"""
name, instruction, lang, data = getNS(packet, 3)
numPrompts = struct.unpack('!L', data[:4])[0]
data = data[4:]
prompts = []
for i in range(numPrompts):
prompt, data = getNS(data)
echo = bool(ord(data[0]))
data = data[1:]
prompts.append((prompt, echo))
d = self.getGenericAnswers(name, instruction, prompts)
d.addCallback(self._cbGenericAnswers)
d.addErrback(self._ebAuth)
def _cbSignedData(self, signedData):
"""
Called back out of self.signData with the signed data. Send the
authentication request with the signature.
@param signedData: the data signed by the user's private key.
@type signedData: C{str}
"""
publicKey = self.lastPublicKey
self.askForAuth('publickey', '\x01' + NS(publicKey.sshType()) +
NS(publicKey.blob()) + NS(signedData))
def _setOldPass(self, op):
"""
Called back when we are choosing a new password. Simply store the old
password for now.
@param op: the old password as entered by the user
@type op: C{str}
"""
self._oldPass = op
def _setNewPass(self, np):
"""
Called back when we are choosing a new password. Get the old password
and send the authentication message with both.
@param np: the new password as entered by the user
@type np: C{str}
"""
op = self._oldPass
self._oldPass = None
self.askForAuth('password', '\xff' + NS(op) + NS(np))
def _cbGenericAnswers(self, responses):
"""
Called back when we are finished answering keyboard-interactive
questions. Send the info back to the server in a
MSG_USERAUTH_INFO_RESPONSE.
@param responses: a list of C{str} responses
@type responses: C{list}
"""
data = struct.pack('!L', len(responses))
for r in responses:
data += NS(r.encode('UTF8'))
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
def auth_publickey(self):
"""
Try to authenticate with a public key. Ask the user for a public key;
if the user has one, send the request to the server and return True.
Otherwise, return False.
@rtype: C{bool}
"""
d = defer.maybeDeferred(self.getPublicKey)
d.addBoth(self._cbGetPublicKey)
return d
def _cbGetPublicKey(self, publicKey):
if not isinstance(publicKey, keys.Key): # failure or None
publicKey = None
if publicKey is not None:
self.lastPublicKey = publicKey
self.triedPublicKeys.append(publicKey)
log.msg('using key of type %s' % publicKey.type())
self.askForAuth('publickey', '\x00' + NS(publicKey.sshType()) +
NS(publicKey.blob()))
return True
else:
return False
def auth_password(self):
"""
Try to authenticate with a password. Ask the user for a password.
If the user will return a password, return True. Otherwise, return
False.
@rtype: C{bool}
"""
d = self.getPassword()
if d:
d.addCallbacks(self._cbPassword, self._ebAuth)
return True
else: # returned None, don't do password auth
return False
def auth_keyboard_interactive(self):
"""
Try to authenticate with keyboard-interactive authentication. Send
the request to the server and return True.
@rtype: C{bool}
"""
log.msg('authing with keyboard-interactive')
self.askForAuth('keyboard-interactive', NS('') + NS(''))
return True
def _cbPassword(self, password):
"""
Called back when the user gives a password. Send the request to the
server.
@param password: the password the user entered
@type password: C{str}
"""
self.askForAuth('password', '\x00' + NS(password))
def signData(self, publicKey, signData):
"""
Sign the given data with the given public key.
By default, this will call getPrivateKey to get the private key,
then sign the data using Key.sign().
This method is factored out so that it can be overridden to use
alternate methods, such as a key agent.
@param publicKey: The public key object returned from L{getPublicKey}
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: C{str}
@return: a Deferred that's called back with the signature
@rtype: L{defer.Deferred}
"""
key = self.getPrivateKey()
if not key:
return
return key.addCallback(self._cbSignData, signData)
def _cbSignData(self, privateKey, signData):
"""
Called back when the private key is returned. Sign the data and
return the signature.
@param privateKey: the private key object
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: C{str}
@return: the signature
@rtype: C{str}
"""
return privateKey.sign(signData)
def getPublicKey(self):
"""
Return a public key for the user. If no more public keys are
available, return C{None}.
This implementation always returns C{None}. Override it in a
subclass to actually find and return a public key object.
@rtype: L{Key} or L{NoneType}
"""
return None
def getPrivateKey(self):
"""
Return a L{Deferred} that will be called back with the private key
object corresponding to the last public key from getPublicKey().
If the private key is not available, errback on the Deferred.
@rtype: L{Deferred} called back with L{Key}
"""
return defer.fail(NotImplementedError())
def getPassword(self, prompt = None):
"""
Return a L{Deferred} that will be called back with a password.
prompt is a string to display for the password, or None for a generic
'user@hostname's password: '.
@type prompt: C{str}/C{None}
@rtype: L{defer.Deferred}
"""
return defer.fail(NotImplementedError())
def getGenericAnswers(self, name, instruction, prompts):
"""
Returns a L{Deferred} with the responses to the promopts.
@param name: The name of the authentication currently in progress.
@param instruction: Describes what the authentication wants.
@param prompts: A list of (prompt, echo) pairs, where prompt is a
string to display and echo is a boolean indicating whether the
user's response should be echoed as they type it.
"""
return defer.fail(NotImplementedError())
MSG_USERAUTH_REQUEST = 50
MSG_USERAUTH_FAILURE = 51
MSG_USERAUTH_SUCCESS = 52
MSG_USERAUTH_BANNER = 53
MSG_USERAUTH_INFO_RESPONSE = 61
MSG_USERAUTH_PK_OK = 60
messages = {}
for k, v in locals().items():
if k[:4]=='MSG_':
messages[v] = k
SSHUserAuthServer.protocolMessages = messages
SSHUserAuthClient.protocolMessages = messages
del messages
del v
# Doubles, not included in the protocols' mappings
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
MSG_USERAUTH_INFO_REQUEST = 60

View File

@ -0,0 +1,95 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Asynchronous local terminal input handling
@author: Jp Calderone
"""
import os, tty, sys, termios
from twisted.internet import reactor, stdio, protocol, defer
from twisted.python import failure, reflect, log
from twisted.conch.insults.insults import ServerProtocol
from twisted.conch.manhole import ColoredManhole
class UnexpectedOutputError(Exception):
pass
class TerminalProcessProtocol(protocol.ProcessProtocol):
def __init__(self, proto):
self.proto = proto
self.onConnection = defer.Deferred()
def connectionMade(self):
self.proto.makeConnection(self)
self.onConnection.callback(None)
self.onConnection = None
def write(self, bytes):
self.transport.write(bytes)
def outReceived(self, bytes):
self.proto.dataReceived(bytes)
def errReceived(self, bytes):
self.transport.loseConnection()
if self.proto is not None:
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(bytes)))
self.proto = None
def childConnectionLost(self, childFD):
if self.proto is not None:
self.proto.childConnectionLost(childFD)
def processEnded(self, reason):
if self.proto is not None:
self.proto.connectionLost(reason)
self.proto = None
class ConsoleManhole(ColoredManhole):
"""
A manhole protocol specifically for use with L{stdio.StandardIO}.
"""
def connectionLost(self, reason):
"""
When the connection is lost, there is nothing more to do. Stop the
reactor so that the process can exit.
"""
reactor.stop()
def runWithProtocol(klass):
fd = sys.__stdin__.fileno()
oldSettings = termios.tcgetattr(fd)
tty.setraw(fd)
try:
p = ServerProtocol(klass)
stdio.StandardIO(p)
reactor.run()
finally:
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
os.write(fd, "\r\x1bc\r")
def main(argv=None):
log.startLogging(file('child.log', 'w'))
if argv is None:
argv = sys.argv[1:]
if argv:
klass = reflect.namedClass(argv[0])
else:
klass = ConsoleManhole
runWithProtocol(klass)
if __name__ == '__main__':
main()

View File

@ -0,0 +1,93 @@
# -*- test-case-name: twisted.conch.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support module for making SSH servers with twistd.
"""
from twisted.conch import unix
from twisted.conch import checkers as conch_checkers
from twisted.conch.openssh_compat import factory
from twisted.cred import portal, checkers, strcred
from twisted.python import usage
from twisted.application import strports
try:
from twisted.cred import pamauth
except ImportError:
pamauth = None
class Options(usage.Options, strcred.AuthOptionMixin):
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
longdesc = ("Makes a Conch SSH server. If no authentication methods are "
"specified, the default authentication methods are UNIX passwords, "
"SSH public keys, and PAM if it is available. If --auth options are "
"passed, only the measures specified will be used.")
optParameters = [
["interface", "i", "", "local interface to which we listen"],
["port", "p", "tcp:22", "Port on which to listen"],
["data", "d", "/etc", "directory to look for host keys in"],
["moduli", "", None, "directory to look for moduli in "
"(if different from --data)"]
]
compData = usage.Completions(
optActions={"data": usage.CompleteDirs(descr="data directory"),
"moduli": usage.CompleteDirs(descr="moduli directory"),
"interface": usage.CompleteNetInterfaces()}
)
def __init__(self, *a, **kw):
usage.Options.__init__(self, *a, **kw)
# call the default addCheckers (for backwards compatibility) that will
# be used if no --auth option is provided - note that conch's
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
# checker
super(Options, self).addChecker(conch_checkers.UNIXPasswordDatabase())
super(Options, self).addChecker(conch_checkers.SSHPublicKeyChecker(
conch_checkers.UNIXAuthorizedKeysFiles()))
if pamauth is not None:
super(Options, self).addChecker(
checkers.PluggableAuthenticationModulesChecker())
self._usingDefaultAuth = True
def addChecker(self, checker):
"""
Add the checker specified. If any checkers are added, the default
checkers are automatically cleared and the only checkers will be the
specified one(s).
"""
if self._usingDefaultAuth:
self['credCheckers'] = []
self['credInterfaces'] = {}
self._usingDefaultAuth = False
super(Options, self).addChecker(checker)
def makeService(config):
"""
Construct a service for operating a SSH server.
@param config: An L{Options} instance specifying server options, including
where server keys are stored and what authentication methods to use.
@return: An L{IService} provider which contains the requested SSH server.
"""
t = factory.OpenSSHFactory()
r = unix.UnixSSHRealm()
t.portal = portal.Portal(r, config.get('credCheckers', []))
t.dataRoot = config['data']
t.moduliRoot = config['moduli'] or config['data']
port = config['port']
if config['interface']:
# Add warning here
port += ':interface=' + config['interface']
return strports.service(port, t)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
'conch tests'

View File

@ -0,0 +1,208 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Data used by test_keys as well as others.
"""
RSAData = {
'n':long('1062486685755247411169438309495398947372127791189432809481'
'382072971106157632182084539383569281493520117634129557550415277'
'516685881326038852354459895734875625093273594925884531272867425'
'864910490065695876046999646807138717162833156501L'),
'e':35L,
'd':long('6678487739032983727350755088256793383481946116047863373882'
'973030104095847973715959961839578340816412167985957218887914482'
'713602371850869127033494910375212470664166001439410214474266799'
'85974425203903884190893469297150446322896587555L'),
'q':long('3395694744258061291019136154000709371890447462086362702627'
'9704149412726577280741108645721676968699696898960891593323L'),
'p':long('3128922844292337321766351031842562691837301298995834258844'
'4720539204069737532863831050930719431498338835415515173887L')}
DSAData = {
'y':long('2300663509295750360093768159135720439490120577534296730713'
'348508834878775464483169644934425336771277908527130096489120714'
'610188630979820723924744291603865L'),
'g':long('4451569990409370769930903934104221766858515498655655091803'
'866645719060300558655677517139568505649468378587802312867198352'
'1161998270001677664063945776405L'),
'p':long('7067311773048598659694590252855127633397024017439939353776'
'608320410518694001356789646664502838652272205440894335303988504'
'978724817717069039110940675621677L'),
'q':1184501645189849666738820838619601267690550087703L,
'x':863951293559205482820041244219051653999559962819L}
publicRSA_openssh = ("ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBE"
"vLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYL"
"h5KmRpslkYHRivcJSkbh/C+BR3utDS555mV comment")
privateRSA_openssh = """-----BEGIN RSA PRIVATE KEY-----
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
-----END RSA PRIVATE KEY-----"""
# some versions of OpenSSH generate these (slightly different keys)
privateRSA_openssh_alternate = """-----BEGIN RSA PRIVATE KEY-----
MIIBzjCCAcgCAQACYQCvMnHw5g6cmbN/i18ES8uLwNU+snf9z2TYj8DPrh/GMd/2
KbJEluLG1CGUf2V82NQjH7guaskflA1GwWmitwcMo5PBNNguHkqZGmyWRgdGK9wl
KRuH8L4FHe60NLnnmZUCASMCYG4ftVWX6+1n7SuZbuzB7ahNUtbz1mUGBN/lVJ/M
iQA8m2eH7GWgq81vZZCKl5BNxiGPqI3YWYZDtYGxtNdfLCIKYcElikcStJr4ehEc
SqiLdcSRCTu+BMpF2VeKDSfLIwIxANyfa9mYIVYRjelfA50K05NuE3dBPIVPAHD9
BVT/vD0Jv4P2l39kEJEE/qJnR1RCawIxAMtKS9BAR+hFUvfHrwwgbUMNtjmU+dql
5QMGdoMk64ihVaKo3hI7d0mSiqlx0gKT/wIwS6Rffc3CSWUatmm4GJX/Zb9XIZK1
qgx1Lg2bbZmCXhH4hQQWr1WCBdXTpWU9BvIzAjAXO7DkmaHRZwIq8j/kIPaLUgYy
d20DC6Uk6su3N2tgEnAv2MjsJA2iAh55w918ozMCMQC0c5dLUBCjF7OoR/E6FHZS
0TgqzxIUNMGoVEwpNYCgOLjw+kzEwoWr24eCutzr2yowAA==
------END RSA PRIVATE KEY------"""
# encrypted with the passphrase 'encrypted'
privateRSA_openssh_encrypted = """-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
30qUR7DYY/rpVJu159paRM1mUqt/IMibfEMTKWSjNhCVD21hskftZCJROw/WgIFt
ncusHpJMkjgwEpho0KyKilcC7zxjpunTex24Meb5pCdXCrYft8AyUkRdq3dugMqT
4nuWuWxziluBhKQ2M9tPGcEOeulU4vVjceZt2pZhZQVBf08o3XUv5/7RYd24M9md
WIo+5zdj2YQkI6xMFTP954O/X32ME1KQt98wgNEy6mxhItbvf00mH3woALwEKP3v
PSMxxtx3VKeDKd9YTOm1giKkXZUf91vZWs0378tUBrU4U5qJxgryTjvvVKOtofj6
4qQy6+r6M6wtwVlXBgeRm2gBPvL3nv6MsROp3E6ztBd/e7A8fSec+UTq3ko/EbGP
0QG+IG5tg8FsdITxQ9WAIITZL3Rc6hA5Ymx1VNhySp3iSiso8Jof27lku4pyuvRV
ko/B3N2H7LnQrGV0GyrjeYocW/qZh/PCsY48JBFhlNQexn2mn44AJW3y5xgbhvKA
3mrmMD1hD17ZvZxi4fPHjbuAyM1vFqhQx63eT9ijbwJ91svKJl5O5MIv41mCRonm
hxvOXw8S0mjSasyofptzzQCtXxFLQigXbpQBltII+Ys=
-----END RSA PRIVATE KEY-----"""
# encrypted with the passphrase 'testxp'. NB: this key was generated by
# OpenSSH, so it doesn't use the same key data as the other keys here.
privateRSA_openssh_encrypted_aes = """-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----"""
publicRSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuOTc6AK8yc"
"fDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW4sbUIZR/ZXzY1CMfuC5qyR+UDUbB"
"aaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fwvgUd7rQ0ueeZlSkoMTplMTojKSkp}")
privateRSA_lsh = ("(11:private-key(9:rsa-pkcs1(1:n97:\x00\xaf2q\xf0\xe6\x0e"
"\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd\xcfd\xd8\x8f\xc0\xcf"
"\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|\xd8\xd4#\x1f\xb8.j"
"\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8.\x1eJ\x99\x1al\x96F"
"\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7\x99\x95)(1:e1:#)(1:d9"
"6:n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#)(1:p49:\x00"
"\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd"
"\x05T\xff\xbc=\t\xbf\x83\xf6\x97\x7fd\x10\x91\x04\xfe\xa2gGTBk)(1:q49:\x00"
"\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r\xb69\x94\xf9\xda\xa5\xe5\x03\x06v"
"\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92\x8a\xa9q\xd2\x02\x93\xff)(1:a48:K"
"\xa4_}\xcd\xc2Ie\x1a\xb6i\xb8\x18\x95\xffe\xbfW!\x92\xb5\xaa\x0cu.\r\x9bm"
"\x99\x82^\x11\xf8\x85\x04\x16\xafU\x82\x05\xd5\xd3\xa5e=\x06\xf23)(1:b48:"
"\x17;\xb0\xe4\x99\xa1\xd1g\x02*\xf2?\xe4 \xf6\x8bR\x062wm\x03\x0b\xa5$\xea"
"\xcb\xb77k`\x12p/\xd8\xc8\xec$\r\xa2\x02\x1ey\xc3\xdd|\xa33)(1:c49:\x00\xb4"
"s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*\xcf\x12\x144\xc1\xa8TL)5\x80"
"\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82\xba\xdc\xeb\xdb*)))")
privateRSA_agentv3 = ("\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00`"
"n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#\x00\x00\x00a"
"\x00\xaf2q\xf0\xe6\x0e\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd"
"\xcfd\xd8\x8f\xc0\xcf\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|"
"\xd8\xd4#\x1f\xb8.j\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8."
"\x1eJ\x99\x1al\x96F\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7"
"\x99\x95\x00\x00\x001\x00\xb4s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*"
"\xcf\x12\x144\xc1\xa8TL)5\x80\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82"
"\xba\xdc\xeb\xdb*\x00\x00\x001\x00\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r"
"\xb69\x94\xf9\xda\xa5\xe5\x03\x06v\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92"
"\x8a\xa9q\xd2\x02\x93\xff\x00\x00\x001\x00\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_"
"\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd\x05T\xff\xbc=\t\xbf\x83\xf6\x97"
"\x7fd\x10\x91\x04\xfe\xa2gGTBk")
publicDSA_openssh = ("ssh-dss AAAAB3NzaC1kc3MAAABBAIbwTOSsZ7Bl7U1KyMNqV13Tu7"
"yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0AAAAVAM965Akmo"
"6eAi7K+k9qDR4TotFAXAAAAQADZlpTW964haQWS4vC063NGdldT6xpUGDcDRqbm90CoPEa2RmNO"
"uOqi8lnbhYraEzypYH3K4Gzv/bxCBnKtHRUAAABAK+1osyWBS0+P90u/rAuko6chZ98thUSY2kL"
"SHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmQ== comment")
privateDSA_openssh = """-----BEGIN DSA PRIVATE KEY-----
MIH4AgEAAkEAhvBM5KxnsGXtTUrIw2pXXdO7vJEC1OvvQ9UjdCd+s+6Z/ZTMKCkv
WWNsrFJ8CLTle+sT/W98IUCwVhdFkuFDLQIVAM965Akmo6eAi7K+k9qDR4TotFAX
AkAA2ZaU1veuIWkFkuLwtOtzRnZXU+saVBg3A0am5vdAqDxGtkZjTrjqovJZ24WK
2hM8qWB9yuBs7/28QgZyrR0VAkAr7WizJYFLT4/3S7+sC6SjpyFn3y2FRJjaQtIe
nqEsrLLZuOdPb2HuFoQsT5cd+rZsz19yQPHYUs5IgnPLzdWZAhUAl1TqdmlAG/b4
nnVchGiO9sML8MM=
-----END DSA PRIVATE KEY-----"""
publicDSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMzpkc2EoMTpwNjU6AIbwTOSsZ7Bl7U1KyMNqV"
"13Tu7yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0pKDE6cTIx"
"OgDPeuQJJqOngIuyvpPag0eE6LRQFykoMTpnNjQ6ANmWlNb3riFpBZLi8LTrc0Z2V1PrGlQYNwN"
"Gpub3QKg8RrZGY0646qLyWduFitoTPKlgfcrgbO/9vEIGcq0dFSkoMTp5NjQ6K+1osyWBS0+P90"
"u/rAuko6chZ98thUSY2kLSHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmSkpK"
"Q==}")
privateDSA_lsh = ("(11:private-key(3:dsa(1:p65:\x00\x86\xf0L\xe4\xacg\xb0e"
"\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3\xee\x99\xfd"
"\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92\xe1C-)(1:q2"
"1:\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G\x84\xe8\xb4P\x17)(1"
":g64:\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2\xf0\xb4\xebsFvWS\xeb\x1aT"
"\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2\xf2Y\xdb\x85\x8a\xda\x13<"
"\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15)(1:y64:+\xedh\xb3%\x81KO\x8f"
"\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98\xdaB\xd2\x1e\x9e\xa1,\xac\xb2"
"\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6l\xcf_r@\xf1\xd8R\xceH\x82s"
"\xcb\xcd\xd5\x99)(1:x21:\x00\x97T\xeavi@\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6"
"\xc3\x0b\xf0\xc3)))")
privateDSA_agentv3 = ("\x00\x00\x00\x07ssh-dss\x00\x00\x00A\x00\x86\xf0L\xe4"
"\xacg\xb0e\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3"
"\xee\x99\xfd\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92"
"\xe1C-\x00\x00\x00\x15\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G"
"\x84\xe8\xb4P\x17\x00\x00\x00@\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2"
"\xf0\xb4\xebsFvWS\xeb\x1aT\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2"
"\xf2Y\xdb\x85\x8a\xda\x13<\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15\x00"
"\x00\x00@+\xedh\xb3%\x81KO\x8f\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98"
"\xdaB\xd2\x1e\x9e\xa1,\xac\xb2\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6"
"l\xcf_r@\xf1\xd8R\xceH\x82s\xcb\xcd\xd5\x99\x00\x00\x00\x15\x00\x97T\xeavi@"
"\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6\xc3\x0b\xf0\xc3")
__all__ = ['DSAData', 'RSAData', 'privateDSA_agentv3', 'privateDSA_lsh',
'privateDSA_openssh', 'privateRSA_agentv3', 'privateRSA_lsh',
'privateRSA_openssh', 'publicDSA_lsh', 'publicDSA_openssh',
'publicRSA_lsh', 'publicRSA_openssh', 'privateRSA_openssh_alternate']

View File

@ -0,0 +1,49 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{SSHTransportAddrress} in ssh/address.py
"""
from twisted.trial import unittest
from twisted.internet.address import IPv4Address
from twisted.internet.test.test_address import AddressTestCaseMixin
from twisted.conch.ssh.address import SSHTransportAddress
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
"""
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
use to represent the other side of the SSH connection. This tests the
basic functionality of that class (string representation, comparison, &c).
"""
def _stringRepresentation(self, stringFunction):
"""
The string representation of C{SSHTransportAddress} should be
"SSHTransportAddress(<stringFunction on address>)".
"""
addr = self.buildAddress()
stringValue = stringFunction(addr)
addressValue = stringFunction(addr.address)
self.assertEqual(stringValue,
"SSHTransportAddress(%s)" % addressValue)
def buildAddress(self):
"""
Create an arbitrary new C{SSHTransportAddress}. A new instance is
created for each call, but always for the same address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
def buildDifferentAddress(self):
"""
Like C{buildAddress}, but with a different fixed address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))

View File

@ -0,0 +1,399 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.agent}.
"""
import struct
from twisted.internet import reactor
from twisted.internet.interfaces import ITLSTransport
from twisted.trial import unittest
if not ITLSTransport.providedBy(reactor):
iosim = None
else:
from twisted.test import iosim
try:
import Crypto.Cipher.DES3
except ImportError:
Crypto = None
try:
import pyasn1
except ImportError:
pyasn1 = None
if Crypto and pyasn1:
from twisted.conch.ssh import keys, agent
else:
keys = agent = None
from twisted.conch.test import keydata
from twisted.conch.error import ConchError, MissingKeyStoreError
class StubFactory(object):
"""
Mock factory that provides the keys attribute required by the
SSHAgentServerProtocol
"""
def __init__(self):
self.keys = {}
class AgentTestBase(unittest.TestCase):
"""
Tests for SSHAgentServer/Client.
"""
if iosim is None:
skip = "iosim requires SSL, but SSL is not available"
elif agent is None or keys is None:
skip = "Cannot run without PyCrypto or PyASN1"
def setUp(self):
# wire up our client <-> server
self.client, self.server, self.pump = iosim.connectedServerAndClient(
agent.SSHAgentServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
# pub/priv keys of each kind
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class ServerProtocolContractWithFactoryTests(AgentTestBase):
"""
The server protocol is stateful and so uses its factory to track state
across requests. This test asserts that the protocol raises if its factory
doesn't provide the necessary storage for that state.
"""
def test_factorySuppliesKeyStorageForServerProtocol(self):
# need a message to send into the server
msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
del self.server.factory.__dict__['keys']
self.assertRaises(MissingKeyStoreError,
self.server.dataReceived, msg)
class UnimplementedVersionOneServerTests(AgentTestBase):
"""
Tests for methods with no-op implementations on the server. We need these
for clients, such as openssh, that try v1 methods before going to v2.
Because the client doesn't expose these operations with nice method names,
we invoke sendRequest directly with an op code.
"""
def test_agentc_REQUEST_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA identities request
"""
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
self.pump.flush()
def _cb(packet):
self.assertEqual(
agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
return d.addCallback(_cb)
def test_agentc_REMOVE_RSA_IDENTITY(self):
"""
assert that we get the correct op code for an RSA remove identity request
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
self.pump.flush()
return d.addCallback(self.assertEqual, '')
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA remove all identities
request.
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
self.pump.flush()
return d.addCallback(self.assertEqual, '')
if agent is not None:
class CorruptServer(agent.SSHAgentServer):
"""
A misbehaving server that returns bogus response op codes so that we can
verify that our callbacks that deal with these op codes handle such
miscreants.
"""
def agentc_REQUEST_IDENTITIES(self, data):
self.sendResponse(254, '')
def agentc_SIGN_REQUEST(self, data):
self.sendResponse(254, '')
class ClientWithBrokenServerTests(AgentTestBase):
"""
verify error handling code in the client using a misbehaving server
"""
def setUp(self):
AgentTestBase.setUp(self)
self.client, self.server, self.pump = iosim.connectedServerAndClient(
CorruptServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
def test_signDataCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.signData} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for data signing requests.
"""
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentitiesCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for identity requests.
"""
d = self.client.requestIdentities()
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentKeyAdditionTests(AgentTestBase):
"""
Test adding different flavors of keys to an agent.
"""
def test_addRSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual('', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual('', serverKey[1])
return d.addCallback(_check)
def test_addRSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.rsaPrivate.privateBlob(), comment='My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual('My special key', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.dsaPrivate.privateBlob(), comment='My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual('My special key', serverKey[1])
return d.addCallback(_check)
class AgentClientFailureTests(AgentTestBase):
def test_agentFailure(self):
"""
verify that the client raises ConchError on AGENT_FAILURE
"""
d = self.client.sendRequest(254, '')
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentIdentityRequestsTests(AgentTestBase):
"""
Test operations against a server with identities already loaded.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, 'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, 'another comment')
def test_signDataRSA(self):
"""
Sign data with an RSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
def _check(sig):
expected = self.rsaPrivate.sign("John Hancock")
self.assertEqual(expected, sig)
self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
return d.addCallback(_check)
def test_signDataDSA(self):
"""
Sign data with a DSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
self.pump.flush()
def _check(sig):
# Cannot do this b/c DSA uses random numbers when signing
# expected = self.dsaPrivate.sign("John Hancock")
# self.assertEqual(expected, sig)
self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
return d.addCallback(_check)
def test_signDataRSAErrbackOnUnknownBlob(self):
"""
Assert that we get an errback if we try to sign data using a key that
wasn't added.
"""
del self.server.factory.keys[self.rsaPublic.blob()]
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentities(self):
"""
Assert that we get all of the keys/comments that we add when we issue a
request for all identities.
"""
d = self.client.requestIdentities()
self.pump.flush()
def _check(keyt):
expected = {}
expected[self.dsaPublic.blob()] = 'a comment'
expected[self.rsaPublic.blob()] = 'another comment'
received = {}
for k in keyt:
received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
self.assertEqual(expected, received)
return d.addCallback(_check)
class AgentKeyRemovalTests(AgentTestBase):
"""
Test support for removing keys in a remote server.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, 'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, 'another comment')
def test_removeRSAIdentity(self):
"""
Assert that we can remove an RSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.rsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeDSAIdentity(self):
"""
Assert that we can remove a DSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.dsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeAllIdentities(self):
"""
Assert that we can remove all identities.
"""
d = self.client.removeAllIdentities()
self.pump.flush()
def _check(ignored):
self.assertEqual(0, len(self.server.factory.keys))
return d.addCallback(_check)

View File

@ -0,0 +1,992 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE file for details.
"""
Tests for L{twisted.conch.scripts.cftp}.
"""
import locale
import time, sys, os, operator, getpass, struct
from StringIO import StringIO
from twisted.conch.test.test_ssh import Crypto, pyasn1
_reason = None
if Crypto and pyasn1:
try:
from twisted.conch import unix
from twisted.conch.scripts import cftp
from twisted.conch.scripts.cftp import SSHSession
from twisted.conch.test.test_filetransfer import FileTransferForTestAvatar
except ImportError as e:
unix = None
_reason = str(e)
del e
else:
unix = None
from twisted.python.fakepwd import UserDatabase
from twisted.trial.unittest import TestCase
from twisted.cred import portal
from twisted.internet import reactor, protocol, interfaces, defer, error
from twisted.internet.utils import getProcessOutputAndValue
from twisted.python import log
from twisted.conch import ls
from twisted.test.proto_helpers import StringTransport
from twisted.internet.task import Clock
from twisted.conch.test import test_ssh, test_conch
from twisted.conch.test.test_filetransfer import SFTPTestBase
from twisted.conch.test.test_filetransfer import FileTransferTestAvatar
from twisted.conch.test.test_conch import FakeStdio
class SSHSessionTests(TestCase):
"""
Tests for L{twisted.conch.scripts.cftp.SSHSession}.
"""
def test_eofReceived(self):
"""
L{twisted.conch.scripts.cftp.SSHSession.eofReceived} loses the write
half of its stdio connection.
"""
stdio = FakeStdio()
channel = SSHSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class ListingTests(TestCase):
"""
Tests for L{lsLine}, the function which generates an entry for a file or
directory in an SFTP I{ls} command's output.
"""
if getattr(time, 'tzset', None) is None:
skip = "Cannot test timestamp formatting code without time.tzset"
def setUp(self):
"""
Patch the L{ls} module's time function so the results of L{lsLine} are
deterministic.
"""
self.now = 123456789
def fakeTime():
return self.now
self.patch(ls, 'time', fakeTime)
# Make sure that the timezone ends up the same after these tests as
# it was before.
if 'TZ' in os.environ:
self.addCleanup(operator.setitem, os.environ, 'TZ', os.environ['TZ'])
self.addCleanup(time.tzset)
else:
def cleanup():
# os.environ.pop is broken! Don't use it! Ever! Or die!
try:
del os.environ['TZ']
except KeyError:
pass
time.tzset()
self.addCleanup(cleanup)
def _lsInTimezone(self, timezone, stat):
"""
Call L{ls.lsLine} after setting the timezone to C{timezone} and return
the result.
"""
# Set the timezone to a well-known value so the timestamps are
# predictable.
os.environ['TZ'] = timezone
time.tzset()
return ls.lsLine('foo', stat)
def test_oldFile(self):
"""
A file with an mtime six months (approximately) or more in the past has
a listing including a low-resolution timestamp.
"""
# Go with 7 months. That's more than 6 months.
then = self.now - (60 * 60 * 24 * 31 * 7)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Apr 26 1973 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Apr 27 1973 foo')
def test_oldSingleDigitDayOfMonth(self):
"""
A file with a high-resolution timestamp which falls on a day of the
month which can be represented by one decimal digit is formatted with
one padding 0 to preserve the columns which come after it.
"""
# A point about 7 months in the past, tweaked to fall on the first of a
# month so we test the case we want to test.
then = self.now - (60 * 60 * 24 * 31 * 7) + (60 * 60 * 24 * 5)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 May 01 1973 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 May 02 1973 foo')
def test_newFile(self):
"""
A file with an mtime fewer than six months (approximately) in the past
has a listing including a high-resolution timestamp excluding the year.
"""
# A point about three months in the past.
then = self.now - (60 * 60 * 24 * 31 * 3)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Aug 28 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Aug 29 09:33 foo')
def test_localeIndependent(self):
"""
The month name in the date is locale independent.
"""
# A point about three months in the past.
then = self.now - (60 * 60 * 24 * 31 * 3)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
# Fake that we're in a language where August is not Aug (e.g.: Spanish)
currentLocale = locale.getlocale()
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
self.addCleanup(locale.setlocale, locale.LC_ALL, currentLocale)
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Aug 28 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Aug 29 09:33 foo')
# if alternate locale is not available, the previous test will be
# skipped, please install this locale for it to run
currentLocale = locale.getlocale()
try:
try:
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
except locale.Error:
test_localeIndependent.skip = "The es_AR.UTF8 locale is not installed."
finally:
locale.setlocale(locale.LC_ALL, currentLocale)
def test_newSingleDigitDayOfMonth(self):
"""
A file with a high-resolution timestamp which falls on a day of the
month which can be represented by one decimal digit is formatted with
one padding 0 to preserve the columns which come after it.
"""
# A point about three months in the past, tweaked to fall on the first
# of a month so we test the case we want to test.
then = self.now - (60 * 60 * 24 * 31 * 3) + (60 * 60 * 24 * 4)
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
self.assertEqual(
self._lsInTimezone('America/New_York', stat),
'!--------- 0 0 0 0 Sep 01 17:33 foo')
self.assertEqual(
self._lsInTimezone('Pacific/Auckland', stat),
'!--------- 0 0 0 0 Sep 02 09:33 foo')
class StdioClientTests(TestCase):
"""
Tests for L{cftp.StdioClient}.
"""
def setUp(self):
"""
Create a L{cftp.StdioClient} hooked up to dummy transport and a fake
user database.
"""
class Connection:
pass
conn = Connection()
conn.transport = StringTransport()
conn.transport.localClosed = False
self.client = cftp.StdioClient(conn)
self.database = self.client._pwd = UserDatabase()
# Intentionally bypassing makeConnection - that triggers some code
# which uses features not provided by our dumb Connection fake.
self.client.transport = StringTransport()
def test_exec(self):
"""
The I{exec} command runs its arguments locally in a child process
using the user's shell.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
sys.executable)
d = self.client._dispatchCommand("exec print 1 + 2")
d.addCallback(self.assertEqual, "3\n")
return d
def test_execWithoutShell(self):
"""
If the local user has no shell, the I{exec} command runs its arguments
using I{/bin/sh}.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar', '')
d = self.client._dispatchCommand("exec echo hello")
d.addCallback(self.assertEqual, "hello\n")
return d
def test_bang(self):
"""
The I{exec} command is run for lines which start with C{"!"}.
"""
self.database.addUser(
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
'/bin/sh')
d = self.client._dispatchCommand("!echo hello")
d.addCallback(self.assertEqual, "hello\n")
return d
def setKnownConsoleSize(self, width, height):
"""
For the duration of this test, patch C{cftp}'s C{fcntl} module to return
a fixed width and height.
@param width: the width in characters
@type width: C{int}
@param height: the height in characters
@type height: C{int}
"""
import tty # local import to avoid win32 issues
class FakeFcntl(object):
def ioctl(self, fd, opt, mutate):
if opt != tty.TIOCGWINSZ:
self.fail("Only window-size queries supported.")
return struct.pack("4H", height, width, 0, 0)
self.patch(cftp, "fcntl", FakeFcntl())
def test_progressReporting(self):
"""
L{StdioClient._printProgressBar} prints a progress description,
including percent done, amount transferred, transfer rate, and time
remaining, all based the given start time, the given L{FileWrapper}'s
progress information and the reactor's current time.
"""
# Use a short, known console width because this simple test doesn't need
# to test the console padding.
self.setKnownConsoleSize(10, 34)
clock = self.client.reactor = Clock()
wrapped = StringIO("x")
wrapped.name = "sample"
wrapper = cftp.FileWrapper(wrapped)
wrapper.size = 1024 * 10
startTime = clock.seconds()
clock.advance(2.0)
wrapper.total += 4096
self.client._printProgressBar(wrapper, startTime)
self.assertEqual(self.client.transport.value(),
"\rsample 40% 4.0kB 2.0kBps 00:03 ")
def test_reportNoProgress(self):
"""
L{StdioClient._printProgressBar} prints a progress description that
indicates 0 bytes transferred if no bytes have been transferred and no
time has passed.
"""
self.setKnownConsoleSize(10, 34)
clock = self.client.reactor = Clock()
wrapped = StringIO("x")
wrapped.name = "sample"
wrapper = cftp.FileWrapper(wrapped)
startTime = clock.seconds()
self.client._printProgressBar(wrapper, startTime)
self.assertEqual(self.client.transport.value(),
"\rsample 0% 0.0B 0.0Bps 00:00 ")
class FileTransferTestRealm:
def __init__(self, testDir):
self.testDir = testDir
def requestAvatar(self, avatarID, mind, *interfaces):
a = FileTransferTestAvatar(self.testDir)
return interfaces[0], a, lambda: None
class SFTPTestProcess(protocol.ProcessProtocol):
"""
Protocol for testing cftp. Provides an interface between Python (where all
the tests are) and the cftp client process (which does the work that is
being tested).
"""
def __init__(self, onOutReceived):
"""
@param onOutReceived: A L{Deferred} to be fired as soon as data is
received from stdout.
"""
self.clearBuffer()
self.onOutReceived = onOutReceived
self.onProcessEnd = None
self._expectingCommand = None
self._processEnded = False
def clearBuffer(self):
"""
Clear any buffered data received from stdout. Should be private.
"""
self.buffer = ''
self._linesReceived = []
self._lineBuffer = ''
def outReceived(self, data):
"""
Called by Twisted when the cftp client prints data to stdout.
"""
log.msg('got %s' % data)
lines = (self._lineBuffer + data).split('\n')
self._lineBuffer = lines.pop(-1)
self._linesReceived.extend(lines)
# XXX - not strictly correct.
# We really want onOutReceived to fire after the first 'cftp>' prompt
# has been received. (See use in OurServerCmdLineClientTests.setUp)
if self.onOutReceived is not None:
d, self.onOutReceived = self.onOutReceived, None
d.callback(data)
self.buffer += data
self._checkForCommand()
def _checkForCommand(self):
prompt = 'cftp> '
if self._expectingCommand and self._lineBuffer == prompt:
buf = '\n'.join(self._linesReceived)
if buf.startswith(prompt):
buf = buf[len(prompt):]
self.clearBuffer()
d, self._expectingCommand = self._expectingCommand, None
d.callback(buf)
def errReceived(self, data):
"""
Called by Twisted when the cftp client prints data to stderr.
"""
log.msg('err: %s' % data)
def getBuffer(self):
"""
Return the contents of the buffer of data received from stdout.
"""
return self.buffer
def runCommand(self, command):
"""
Issue the given command via the cftp client. Return a C{Deferred} that
fires when the server returns a result. Note that the C{Deferred} will
callback even if the server returns some kind of error.
@param command: A string containing an sftp command.
@return: A C{Deferred} that fires when the sftp server returns a
result. The payload is the server's response string.
"""
self._expectingCommand = defer.Deferred()
self.clearBuffer()
self.transport.write(command + '\n')
return self._expectingCommand
def runScript(self, commands):
"""
Run each command in sequence and return a Deferred that fires when all
commands are completed.
@param commands: A list of strings containing sftp commands.
@return: A C{Deferred} that fires when all commands are completed. The
payload is a list of response strings from the server, in the same
order as the commands.
"""
sem = defer.DeferredSemaphore(1)
dl = [sem.run(self.runCommand, command) for command in commands]
return defer.gatherResults(dl)
def killProcess(self):
"""
Kill the process if it is still running.
If the process is still running, sends a KILL signal to the transport
and returns a C{Deferred} which fires when L{processEnded} is called.
@return: a C{Deferred}.
"""
if self._processEnded:
return defer.succeed(None)
self.onProcessEnd = defer.Deferred()
self.transport.signalProcess('KILL')
return self.onProcessEnd
def processEnded(self, reason):
"""
Called by Twisted when the cftp client process ends.
"""
self._processEnded = True
if self.onProcessEnd:
d, self.onProcessEnd = self.onProcessEnd, None
d.callback(None)
class CFTPClientTestBase(SFTPTestBase):
def setUp(self):
f = open('dsa_test.pub','w')
f.write(test_ssh.publicDSA_openssh)
f.close()
f = open('dsa_test','w')
f.write(test_ssh.privateDSA_openssh)
f.close()
os.chmod('dsa_test', 33152)
f = open('kh_test','w')
f.write('127.0.0.1 ' + test_ssh.publicRSA_openssh)
f.close()
return SFTPTestBase.setUp(self)
def startServer(self):
realm = FileTransferTestRealm(self.testDir)
p = portal.Portal(realm)
p.registerChecker(test_ssh.conchTestPublicKeyChecker())
fac = test_ssh.ConchTestServerFactory()
fac.portal = p
self.server = reactor.listenTCP(0, fac, interface="127.0.0.1")
def stopServer(self):
if not hasattr(self.server.factory, 'proto'):
return self._cbStopServer(None)
self.server.factory.proto.expectedLoseConnection = 1
d = defer.maybeDeferred(
self.server.factory.proto.transport.loseConnection)
d.addCallback(self._cbStopServer)
return d
def _cbStopServer(self, ignored):
return defer.maybeDeferred(self.server.stopListening)
def tearDown(self):
for f in ['dsa_test.pub', 'dsa_test', 'kh_test']:
try:
os.remove(f)
except:
pass
return SFTPTestBase.tearDown(self)
class OurServerCmdLineClientTests(CFTPClientTestBase):
def setUp(self):
CFTPClientTestBase.setUp(self)
self.startServer()
cmds = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-i dsa_test '
'-a '
'-v '
'127.0.0.1')
port = self.server.getHost().port
cmds = test_conch._makeArgs((cmds % port).split(), mod='cftp')
log.msg('running %s %s' % (sys.executable, cmds))
d = defer.Deferred()
self.processProtocol = SFTPTestProcess(d)
d.addCallback(lambda _: self.processProtocol.clearBuffer())
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
reactor.spawnProcess(self.processProtocol, sys.executable, cmds,
env=env)
return d
def tearDown(self):
d = self.stopServer()
d.addCallback(lambda _: self.processProtocol.killProcess())
return d
def _killProcess(self, ignored):
try:
self.processProtocol.transport.signalProcess('KILL')
except error.ProcessExitedAlready:
pass
def runCommand(self, command):
"""
Run the given command with the cftp client. Return a C{Deferred} that
fires when the command is complete. Payload is the server's output for
that command.
"""
return self.processProtocol.runCommand(command)
def runScript(self, *commands):
"""
Run the given commands with the cftp client. Returns a C{Deferred}
that fires when the commands are all complete. The C{Deferred}'s
payload is a list of output for each command.
"""
return self.processProtocol.runScript(commands)
def testCdPwd(self):
"""
Test that 'pwd' reports the current remote directory, that 'lpwd'
reports the current local directory, and that changing to a
subdirectory then changing to its parent leaves you in the original
remote directory.
"""
# XXX - not actually a unit test, see docstring.
homeDir = os.path.join(os.getcwd(), self.testDir)
d = self.runScript('pwd', 'lpwd', 'cd testDirectory', 'cd ..', 'pwd')
d.addCallback(lambda xs: xs[:3] + xs[4:])
d.addCallback(self.assertEqual,
[homeDir, os.getcwd(), '', homeDir])
return d
def testChAttrs(self):
"""
Check that 'ls -l' output includes the access permissions and that
this output changes appropriately with 'chmod'.
"""
def _check(results):
self.flushLoggedErrors()
self.assertTrue(results[0].startswith('-rw-r--r--'))
self.assertEqual(results[1], '')
self.assertTrue(results[2].startswith('----------'), results[2])
self.assertEqual(results[3], '')
d = self.runScript('ls -l testfile1', 'chmod 0 testfile1',
'ls -l testfile1', 'chmod 644 testfile1')
return d.addCallback(_check)
# XXX test chgrp/own
def testList(self):
"""
Check 'ls' works as expected. Checks for wildcards, hidden files,
listing directories and listing empty directories.
"""
def _check(results):
self.assertEqual(results[0], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
self.assertEqual(results[1], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
self.assertEqual(results[2], ['testRemoveFile', 'testRenameFile'])
self.assertEqual(results[3], ['.testHiddenFile', 'testRemoveFile',
'testRenameFile'])
self.assertEqual(results[4], [''])
d = self.runScript('ls', 'ls ../' + os.path.basename(self.testDir),
'ls *File', 'ls -a *File', 'ls -l testDirectory')
d.addCallback(lambda xs: [x.split('\n') for x in xs])
return d.addCallback(_check)
def testHelp(self):
"""
Check that running the '?' command returns help.
"""
d = self.runCommand('?')
d.addCallback(self.assertEqual,
cftp.StdioClient(None).cmd_HELP('').strip())
return d
def assertFilesEqual(self, name1, name2, msg=None):
"""
Assert that the files at C{name1} and C{name2} contain exactly the
same data.
"""
f1 = file(name1).read()
f2 = file(name2).read()
self.assertEqual(f1, f2, msg)
def testGet(self):
"""
Test that 'get' saves the remote file to the correct local location,
that the output of 'get' is correct and that 'rm' actually removes
the file.
"""
# XXX - not actually a unit test
expectedOutput = ("Transferred %s/%s/testfile1 to %s/test file2"
% (os.getcwd(), self.testDir, self.testDir))
def _checkGet(result):
self.assertTrue(result.endswith(expectedOutput))
self.assertFilesEqual(self.testDir + '/testfile1',
self.testDir + '/test file2',
"get failed")
return self.runCommand('rm "test file2"')
d = self.runCommand('get testfile1 "%s/test file2"' % (self.testDir,))
d.addCallback(_checkGet)
d.addCallback(lambda _: self.assertFalse(
os.path.exists(self.testDir + '/test file2')))
return d
def testWildcardGet(self):
"""
Test that 'get' works correctly when given wildcard parameters.
"""
def _check(ignored):
self.assertFilesEqual(self.testDir + '/testRemoveFile',
'testRemoveFile',
'testRemoveFile get failed')
self.assertFilesEqual(self.testDir + '/testRenameFile',
'testRenameFile',
'testRenameFile get failed')
d = self.runCommand('get testR*')
return d.addCallback(_check)
def testPut(self):
"""
Check that 'put' uploads files correctly and that they can be
successfully removed. Also check the output of the put command.
"""
# XXX - not actually a unit test
expectedOutput = ('Transferred %s/testfile1 to %s/%s/test"file2'
% (self.testDir, os.getcwd(), self.testDir))
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/testfile1',
self.testDir + '/test"file2')
self.assertTrue(result.endswith(expectedOutput))
return self.runCommand('rm "test\\"file2"')
d = self.runCommand('put %s/testfile1 "test\\"file2"'
% (self.testDir,))
d.addCallback(_checkPut)
d.addCallback(lambda _: self.assertFalse(
os.path.exists(self.testDir + '/test"file2')))
return d
def test_putOverLongerFile(self):
"""
Check that 'put' uploads files correctly when overwriting a longer
file.
"""
# XXX - not actually a unit test
f = file(os.path.join(self.testDir, 'shorterFile'), 'w')
f.write("a")
f.close()
f = file(os.path.join(self.testDir, 'longerFile'), 'w')
f.write("bb")
f.close()
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/shorterFile',
self.testDir + '/longerFile')
d = self.runCommand('put %s/shorterFile longerFile'
% (self.testDir,))
d.addCallback(_checkPut)
return d
def test_putMultipleOverLongerFile(self):
"""
Check that 'put' uploads files correctly when overwriting a longer
file and you use a wildcard to specify the files to upload.
"""
# XXX - not actually a unit test
os.mkdir(os.path.join(self.testDir, 'dir'))
f = file(os.path.join(self.testDir, 'dir', 'file'), 'w')
f.write("a")
f.close()
f = file(os.path.join(self.testDir, 'file'), 'w')
f.write("bb")
f.close()
def _checkPut(result):
self.assertFilesEqual(self.testDir + '/dir/file',
self.testDir + '/file')
d = self.runCommand('put %s/dir/*'
% (self.testDir,))
d.addCallback(_checkPut)
return d
def testWildcardPut(self):
"""
What happens if you issue a 'put' command and include a wildcard (i.e.
'*') in parameter? Check that all files matching the wildcard are
uploaded to the correct directory.
"""
def check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[2], '')
self.assertFilesEqual(self.testDir + '/testRemoveFile',
self.testDir + '/../testRemoveFile',
'testRemoveFile get failed')
self.assertFilesEqual(self.testDir + '/testRenameFile',
self.testDir + '/../testRenameFile',
'testRenameFile get failed')
d = self.runScript('cd ..',
'put %s/testR*' % (self.testDir,),
'cd %s' % os.path.basename(self.testDir))
d.addCallback(check)
return d
def testLink(self):
"""
Test that 'ln' creates a file which appears as a link in the output of
'ls'. Check that removing the new file succeeds without output.
"""
def _check(results):
self.flushLoggedErrors()
self.assertEqual(results[0], '')
self.assertTrue(results[1].startswith('l'), 'link failed')
return self.runCommand('rm testLink')
d = self.runScript('ln testLink testfile1', 'ls -l testLink')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
def testRemoteDirectory(self):
"""
Test that we can create and remove directories with the cftp client.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertTrue(results[1].startswith('d'))
return self.runCommand('rmdir testMakeDirectory')
d = self.runScript('mkdir testMakeDirectory',
'ls -l testMakeDirector?')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
def test_existingRemoteDirectory(self):
"""
Test that a C{mkdir} on an existing directory fails with the
appropriate error, and doesn't log an useless error server side.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[1],
'remote error 11: mkdir failed')
d = self.runScript('mkdir testMakeDirectory',
'mkdir testMakeDirectory')
d.addCallback(_check)
return d
def testLocalDirectory(self):
"""
Test that we can create a directory locally and remove it with the
cftp client. This test works because the 'remote' server is running
out of a local directory.
"""
d = self.runCommand('lmkdir %s/testLocalDirectory' % (self.testDir,))
d.addCallback(self.assertEqual, '')
d.addCallback(lambda _: self.runCommand('rmdir testLocalDirectory'))
d.addCallback(self.assertEqual, '')
return d
def testRename(self):
"""
Test that we can rename a file.
"""
def _check(results):
self.assertEqual(results[0], '')
self.assertEqual(results[1], 'testfile2')
return self.runCommand('rename testfile2 testfile1')
d = self.runScript('rename testfile1 testfile2', 'ls testfile?')
d.addCallback(_check)
d.addCallback(self.assertEqual, '')
return d
class OurServerBatchFileTests(CFTPClientTestBase):
def setUp(self):
CFTPClientTestBase.setUp(self)
self.startServer()
def tearDown(self):
CFTPClientTestBase.tearDown(self)
return self.stopServer()
def _getBatchOutput(self, f):
fn = self.mktemp()
open(fn, 'w').write(f)
port = self.server.getHost().port
cmds = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-i dsa_test '
'-a '
'-v -b %s 127.0.0.1') % (port, fn)
cmds = test_conch._makeArgs(cmds.split(), mod='cftp')[1:]
log.msg('running %s %s' % (sys.executable, cmds))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
self.server.factory.expectedLoseConnection = 1
d = getProcessOutputAndValue(sys.executable, cmds, env=env)
def _cleanup(res):
os.remove(fn)
return res
d.addCallback(lambda res: res[0])
d.addBoth(_cleanup)
return d
def testBatchFile(self):
"""Test whether batch file function of cftp ('cftp -b batchfile').
This works by treating the file as a list of commands to be run.
"""
cmds = """pwd
ls
exit
"""
def _cbCheckResult(res):
res = res.split('\n')
log.msg('RES %s' % str(res))
self.assertIn(self.testDir, res[1])
self.assertEqual(res[3:-2], ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1'])
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
def testError(self):
"""Test that an error in the batch file stops running the batch.
"""
cmds = """chown 0 missingFile
pwd
exit
"""
def _cbCheckResult(res):
self.assertNotIn(self.testDir, res)
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
def testIgnoredError(self):
"""Test that a minus sign '-' at the front of a line ignores
any errors.
"""
cmds = """-chown 0 missingFile
pwd
exit
"""
def _cbCheckResult(res):
self.assertIn(self.testDir, res)
d = self._getBatchOutput(cmds)
d.addCallback(_cbCheckResult)
return d
class OurServerSftpClientTests(CFTPClientTestBase):
"""
Test the sftp server against sftp command line client.
"""
def setUp(self):
CFTPClientTestBase.setUp(self)
return self.startServer()
def tearDown(self):
return self.stopServer()
def test_extendedAttributes(self):
"""
Test the return of extended attributes by the server: the sftp client
should ignore them, but still be able to parse the response correctly.
This test is mainly here to check that
L{filetransfer.FILEXFER_ATTR_EXTENDED} has the correct value.
"""
fn = self.mktemp()
open(fn, 'w').write("ls .\nexit")
port = self.server.getHost().port
oldGetAttr = FileTransferForTestAvatar._getAttrs
def _getAttrs(self, s):
attrs = oldGetAttr(self, s)
attrs["ext_foo"] = "bar"
return attrs
self.patch(FileTransferForTestAvatar, "_getAttrs", _getAttrs)
self.server.factory.expectedLoseConnection = True
cmds = ('-o', 'IdentityFile=dsa_test',
'-o', 'UserKnownHostsFile=kh_test',
'-o', 'HostKeyAlgorithms=ssh-rsa',
'-o', 'Port=%i' % (port,), '-b', fn, 'testuser@127.0.0.1')
d = getProcessOutputAndValue("sftp", cmds)
def check(result):
self.assertEqual(result[2], 0)
for i in ['testDirectory', 'testRemoveFile',
'testRenameFile', 'testfile1']:
self.assertIn(i, result[0])
return d.addCallback(check)
if unix is None or Crypto is None or pyasn1 is None or interfaces.IReactorProcess(reactor, None) is None:
if _reason is None:
_reason = "don't run w/o spawnProcess or PyCrypto or pyasn1"
OurServerCmdLineClientTests.skip = _reason
OurServerBatchFileTests.skip = _reason
OurServerSftpClientTests.skip = _reason
StdioClientTests.skip = _reason
SSHSessionTests.skip = _reason
else:
from twisted.python.procutils import which
if not which('sftp'):
OurServerSftpClientTests.skip = "no sftp command-line client available"

View File

@ -0,0 +1,279 @@
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
# See LICENSE for details
"""
Test ssh/channel.py.
"""
from twisted.conch.ssh import channel
from twisted.trial import unittest
class MockTransport(object):
"""
A mock Transport. All we use is the getPeer() and getHost() methods.
Channels implement the ITransport interface, and their getPeer() and
getHost() methods return ('SSH', <transport's getPeer/Host value>) so
we need to implement these methods so they have something to draw
from.
"""
def getPeer(self):
return ('MockPeer',)
def getHost(self):
return ('MockHost',)
class MockConnection(object):
"""
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
that channels send, and when they try to close the connection.
@ivar data: a C{dict} mapping channel id #s to lists of data sent by that
channel.
@ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
(extended data type, data) sent by that channel.
@ivar closes: a C{dict} mapping channel id #s to True if that channel sent
a close message.
"""
transport = MockTransport()
def __init__(self):
self.data = {}
self.extData = {}
self.closes = {}
def logPrefix(self):
"""
Return our logging prefix.
"""
return "MockConnection"
def sendData(self, channel, data):
"""
Record the sent data.
"""
self.data.setdefault(channel, []).append(data)
def sendExtendedData(self, channel, type, data):
"""
Record the sent extended data.
"""
self.extData.setdefault(channel, []).append((type, data))
def sendClose(self, channel):
"""
Record that the channel sent a close message.
"""
self.closes[channel] = True
class ChannelTests(unittest.TestCase):
def setUp(self):
"""
Initialize the channel. remoteMaxPacket is 10 so that data is able
to be sent (the default of 0 means no data is sent because no packets
are made).
"""
self.conn = MockConnection()
self.channel = channel.SSHChannel(conn=self.conn,
remoteMaxPacket=10)
self.channel.name = 'channel'
def test_init(self):
"""
Test that SSHChannel initializes correctly. localWindowSize defaults
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
defaults (what OpenSSH uses for those variables).
The values in the second set of assertions are meaningless; they serve
only to verify that the instance variables are assigned in the correct
order.
"""
c = channel.SSHChannel(conn=self.conn)
self.assertEqual(c.localWindowSize, 131072)
self.assertEqual(c.localWindowLeft, 131072)
self.assertEqual(c.localMaxPacket, 32768)
self.assertEqual(c.remoteWindowLeft, 0)
self.assertEqual(c.remoteMaxPacket, 0)
self.assertEqual(c.conn, self.conn)
self.assertEqual(c.data, None)
self.assertEqual(c.avatar, None)
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(c2.localWindowSize, 1)
self.assertEqual(c2.localWindowLeft, 1)
self.assertEqual(c2.localMaxPacket, 2)
self.assertEqual(c2.remoteWindowLeft, 3)
self.assertEqual(c2.remoteMaxPacket, 4)
self.assertEqual(c2.conn, 5)
self.assertEqual(c2.data, 6)
self.assertEqual(c2.avatar, 7)
def test_str(self):
"""
Test that str(SSHChannel) works gives the channel name and local and
remote windows at a glance..
"""
self.assertEqual(str(self.channel), '<SSHChannel channel (lw 131072 '
'rw 0)>')
def test_logPrefix(self):
"""
Test that SSHChannel.logPrefix gives the name of the channel, the
local channel ID and the underlying connection.
"""
self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
'(unknown) on MockConnection')
def test_addWindowBytes(self):
"""
Test that addWindowBytes adds bytes to the window and resumes writing
if it was paused.
"""
cb = [False]
def stubStartWriting():
cb[0] = True
self.channel.startWriting = stubStartWriting
self.channel.write('test')
self.channel.writeExtended(1, 'test')
self.channel.addWindowBytes(50)
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
self.assertTrue(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(self.channel.buf, '')
self.assertEqual(self.conn.data[self.channel], ['test'])
self.assertEqual(self.channel.extBuf, [])
self.assertEqual(self.conn.extData[self.channel], [(1, 'test')])
cb[0] = False
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
self.channel.write('a'*80)
self.channel.loseConnection()
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
def test_requestReceived(self):
"""
Test that requestReceived handles requests by dispatching them to
request_* methods.
"""
self.channel.request_test_method = lambda data: data == ''
self.assertTrue(self.channel.requestReceived('test-method', ''))
self.assertFalse(self.channel.requestReceived('test-method', 'a'))
self.assertFalse(self.channel.requestReceived('bad-method', ''))
def test_closeReceieved(self):
"""
Test that the default closeReceieved closes the connection.
"""
self.assertFalse(self.channel.closing)
self.channel.closeReceived()
self.assertTrue(self.channel.closing)
def test_write(self):
"""
Test that write handles data correctly. Send data up to the size
of the remote window, splitting the data into packets of length
remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.write('d')
self.channel.write('a')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.write('ta')
data = self.conn.data[self.channel]
self.assertEqual(data, ['da', 'ta'])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.write('12345678901')
self.assertEqual(data, ['da', 'ta', '1234567890', '1'])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.write('123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, ['da', 'ta', '1234567890', '1', '12345'])
self.assertEqual(self.channel.buf, '6')
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeExtended(self):
"""
Test that writeExtended handles data correctly. Send extended data
up to the size of the window, splitting the extended data into packets
of length remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.writeExtended(1, 'd')
self.channel.writeExtended(1, 'a')
self.channel.writeExtended(2, 't')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.writeExtended(2, 'a')
data = self.conn.extData[self.channel]
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a')])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.writeExtended(3, '12345678901')
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
(3, '1234567890'), (3, '1')])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.writeExtended(4, '123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
(3, '1234567890'), (3, '1'), (4, '12345')])
self.assertEqual(self.channel.extBuf, [[4, '6']])
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeSequence(self):
"""
Test that writeSequence is equivalent to write(''.join(sequece)).
"""
self.channel.addWindowBytes(20)
self.channel.writeSequence(map(str, range(10)))
self.assertEqual(self.conn.data[self.channel], ['0123456789'])
def test_loseConnection(self):
"""
Tesyt that loseConnection() doesn't close the channel until all
the data is sent.
"""
self.channel.write('data')
self.channel.writeExtended(1, 'datadata')
self.channel.loseConnection()
self.assertEqual(self.conn.closes.get(self.channel), None)
self.channel.addWindowBytes(4) # send regular data
self.assertEqual(self.conn.closes.get(self.channel), None)
self.channel.addWindowBytes(8) # send extended data
self.assertTrue(self.conn.closes.get(self.channel))
def test_getPeer(self):
"""
Test that getPeer() returns ('SSH', <connection transport peer>).
"""
self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))
def test_getHost(self):
"""
Test that getHost() returns ('SSH', <connection transport host>).
"""
self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))

View File

@ -0,0 +1,892 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.checkers}.
"""
try:
import crypt
except ImportError:
cryptSkip = 'cannot run without crypt module'
else:
cryptSkip = None
import os, base64
from collections import namedtuple
from io import StringIO
from zope.interface.verify import verifyObject
from twisted.python import util
from twisted.python.failure import Failure
from twisted.python.reflect import requireModule
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.credentials import UsernamePassword, IUsernamePassword, \
SSHPrivateKey, ISSHPrivateKey
from twisted.cred.error import UnhandledCredentials, UnauthorizedLogin
from twisted.python.fakepwd import UserDatabase, ShadowDatabase
from twisted.test.test_process import MockOS
if requireModule('Crypto.Cipher.DES3') and requireModule('pyasn1'):
dependencySkip = None
from twisted.conch.ssh import keys
from twisted.conch import checkers
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
from twisted.conch.test import keydata
else:
dependencySkip = "can't run without Crypto and PyASN1"
if getattr(os, 'geteuid', None) is None:
euidSkip = "Cannot run without effective UIDs (questionable)"
else:
euidSkip = None
class HelperTests(TestCase):
"""
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
L{_shadowGetByName}.
"""
skip = cryptSkip or dependencySkip
def setUp(self):
self.mockos = MockOS()
def test_verifyCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{True} if the plaintext password
passed to it matches the encrypted password passed to it.
"""
password = 'secret string'
salt = 'salty'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %r' % (
crypted, password))
def test_verifyCryptedPasswordMD5(self):
"""
L{verifyCryptedPassword} returns True if the provided cleartext password
matches the provided MD5 password hash.
"""
password = 'password'
salt = '$1$salt'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %s' % (
crypted, password))
def test_refuteCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{False} if the plaintext password
passed to it does not match the encrypted password passed to it.
"""
password = 'string secret'
wrong = 'secret string'
crypted = crypt.crypt(password, password)
self.assertFalse(
checkers.verifyCryptedPassword(crypted, wrong),
'%r not supposed to be valid encrypted password for %s' % (
crypted, wrong))
def test_pwdGetByName(self):
"""
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
database if the L{pwd} module is present.
"""
userdb = UserDatabase()
userdb.addUser(
'alice', 'secrit', 1, 2, 'first last', '/foo', '/bin/sh')
self.patch(checkers, 'pwd', userdb)
self.assertEqual(
checkers._pwdGetByName('alice'), userdb.getpwnam('alice'))
def test_pwdGetByNameWithoutPwd(self):
"""
If the C{pwd} module isn't present, L{_pwdGetByName} returns C{None}.
"""
self.patch(checkers, 'pwd', None)
self.assertIs(checkers._pwdGetByName('alice'), None)
def test_shadowGetByName(self):
"""
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
database if the L{spwd} is present.
"""
userdb = ShadowDatabase()
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, 'spwd', userdb)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(util, 'os', self.mockos)
self.assertEqual(
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutSpwd(self):
"""
L{_shadowGetByName} uses the C{shadow} module to return a tuple of items
from the UNIX /etc/shadow database if the C{spwd} module is not present
and the C{shadow} module is.
"""
userdb = ShadowDatabase()
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, 'spwd', None)
self.patch(checkers, 'shadow', userdb)
self.patch(util, 'os', self.mockos)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.assertEqual(
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutEither(self):
"""
L{_shadowGetByName} returns C{None} if neither C{spwd} nor C{shadow} is
present.
"""
self.patch(checkers, 'spwd', None)
self.patch(checkers, 'shadow', None)
self.assertIs(checkers._shadowGetByName('bob'), None)
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
class SSHPublicKeyDatabaseTests(TestCase):
"""
Tests for L{SSHPublicKeyDatabase}.
"""
skip = euidSkip or dependencySkip
def setUp(self):
self.checker = checkers.SSHPublicKeyDatabase()
self.key1 = base64.encodestring("foobar")
self.key2 = base64.encodestring("eggspam")
self.content = "t1 %s foo\nt2 %s egg\n" % (self.key1, self.key2)
self.mockos = MockOS()
self.mockos.path = FilePath(self.mktemp())
self.mockos.path.makedirs()
self.patch(util, 'os', self.mockos)
self.sshDir = self.mockos.path.child('.ssh')
self.sshDir.makedirs()
userdb = UserDatabase()
userdb.addUser(
'user', 'password', 1, 2, 'first last',
self.mockos.path.path, '/bin/shell')
self.checker._userdb = userdb
def test_deprecated(self):
"""
L{SSHPublicKeyDatabase} is deprecated as of version 15.0
"""
warningsShown = self.flushWarnings(
offendingFunctions=[self.setUp])
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
"twisted.conch.checkers.SSHPublicKeyDatabase "
"was deprecated in Twisted 15.0.0: Please use "
"twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead.")
self.assertEqual(len(warningsShown), 1)
def _testCheckKey(self, filename):
self.sshDir.child(filename).setContent(self.content)
user = UsernamePassword("user", "password")
user.blob = "foobar"
self.assertTrue(self.checker.checkKey(user))
user.blob = "eggspam"
self.assertTrue(self.checker.checkKey(user))
user.blob = "notallowed"
self.assertFalse(self.checker.checkKey(user))
def test_checkKey(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys file and check the keys against that file.
"""
self._testCheckKey("authorized_keys")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKey2(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys2 file and check the keys against that file.
"""
self._testCheckKey("authorized_keys2")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKeyAsRoot(self):
"""
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
switch its uid/gid to the ones of the authenticated user.
"""
keyFile = self.sshDir.child("authorized_keys")
keyFile.setContent(self.content)
# Fake permission error by changing the mode
keyFile.chmod(0000)
self.addCleanup(keyFile.chmod, 0777)
# And restore the right mode when seteuid is called
savedSeteuid = self.mockos.seteuid
def seteuid(euid):
keyFile.chmod(0777)
return savedSeteuid(euid)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(self.mockos, "seteuid", seteuid)
self.patch(util, 'os', self.mockos)
user = UsernamePassword("user", "password")
user.blob = "foobar"
self.assertTrue(self.checker.checkKey(user))
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
def test_requestAvatarId(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
passed in if its C{_checkKey} method returns True.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
d = self.checker.requestAvatarId(credentials)
def _verify(avatarId):
self.assertEqual(avatarId, 'test')
return d.addCallback(_verify)
def test_requestAvatarIdWithoutSignature(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
if the credentials represent a valid key without a signature. This
tells the user that the key is valid for login, but does not actually
allow that user to do so without a signature.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, None, None)
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, ValidPublicKey)
def test_requestAvatarIdInvalidKey(self):
"""
If L{SSHPublicKeyDatabase.checkKey} returns False,
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
"""
def _checkKey(ignored):
return False
self.patch(self.checker, 'checkKey', _checkKey)
d = self.checker.requestAvatarId(None);
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdInvalidSignature(self):
"""
Valid keys with invalid signatures should cause
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
failure
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdNormalizeException(self):
"""
Exceptions raised while verifying the key should be normalized into an
C{UnauthorizedLogin} failure.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey('test', None, 'blob', 'sigData', 'sig')
d = self.checker.requestAvatarId(credentials)
def _verifyLoggedException(failure):
errors = self.flushLoggedErrors(keys.BadKeyError)
self.assertEqual(len(errors), 1)
return failure
d.addErrback(_verifyLoggedException)
return self.assertFailure(d, UnauthorizedLogin)
class SSHProtocolCheckerTests(TestCase):
"""
Tests for L{SSHProtocolChecker}.
"""
skip = dependencySkip
def test_registerChecker(self):
"""
L{SSHProcotolChecker.registerChecker} should add the given checker to
the list of registered checkers.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(), )
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
self.assertIsInstance(checker.checkers[ISSHPrivateKey],
checkers.SSHPublicKeyDatabase)
def test_registerCheckerWithInterface(self):
"""
If a apecific interface is passed into
L{SSHProtocolChecker.registerChecker}, that interface should be
registered instead of what the checker specifies in
credentialIntefaces.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(),
IUsernamePassword)
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
self.assertIsInstance(checker.checkers[IUsernamePassword],
checkers.SSHPublicKeyDatabase)
def test_requestAvatarId(self):
"""
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
registered checkers to authenticate a user.
"""
checker = checkers.SSHProtocolChecker()
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser('test', 'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
def _callback(avatarId):
self.assertEqual(avatarId, 'test')
return d.addCallback(_callback)
def test_requestAvatarIdWithNotEnoughAuthentication(self):
"""
If the client indicates that it is never satisfied, by always returning
False from _areDone, then L{SSHProtocolChecker} should raise
L{NotEnoughAuthentication}.
"""
checker = checkers.SSHProtocolChecker()
def _areDone(avatarId):
return False
self.patch(checker, 'areDone', _areDone)
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser('test', 'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
return self.assertFailure(d, NotEnoughAuthentication)
def test_requestAvatarIdInvalidCredential(self):
"""
If the passed credentials aren't handled by any registered checker,
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
"""
checker = checkers.SSHProtocolChecker()
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
return self.assertFailure(d, UnhandledCredentials)
def test_areDone(self):
"""
The default L{SSHProcotolChecker.areDone} should simply return True.
"""
self.assertEqual(checkers.SSHProtocolChecker().areDone(None), True)
class UNIXPasswordDatabaseTests(TestCase):
"""
Tests for L{UNIXPasswordDatabase}.
"""
skip = cryptSkip or dependencySkip
def assertLoggedIn(self, d, username):
"""
Assert that the L{Deferred} passed in is called back with the value
'username'. This represents a valid login for this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{Deferred}
"""
result = []
d.addBoth(result.append)
self.assertEqual(len(result), 1, "login incomplete")
if isinstance(result[0], Failure):
result[0].raiseException()
self.assertEqual(result[0], username)
def test_defaultCheckers(self):
"""
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
and then the C{spwd} database.
"""
checker = checkers.UNIXPasswordDatabase()
def crypted(username, password):
salt = crypt.crypt(password, username)
crypted = crypt.crypt(password, '$1$' + salt)
return crypted
pwd = UserDatabase()
pwd.addUser('alice', crypted('alice', 'password'),
1, 2, 'foo', '/foo', '/bin/sh')
# x and * are convention for "look elsewhere for the password"
pwd.addUser('bob', 'x', 1, 2, 'bar', '/bar', '/bin/sh')
spwd = ShadowDatabase()
spwd.addUser('alice', 'wrong', 1, 2, 3, 4, 5, 6, 7)
spwd.addUser('bob', crypted('bob', 'password'),
8, 9, 10, 11, 12, 13, 14)
self.patch(checkers, 'pwd', pwd)
self.patch(checkers, 'spwd', spwd)
mockos = MockOS()
self.patch(util, 'os', mockos)
mockos.euid = 2345
mockos.egid = 1234
cred = UsernamePassword("alice", "password")
self.assertLoggedIn(checker.requestAvatarId(cred), 'alice')
self.assertEqual(mockos.seteuidCalls, [])
self.assertEqual(mockos.setegidCalls, [])
cred.username = "bob"
self.assertLoggedIn(checker.requestAvatarId(cred), 'bob')
self.assertEqual(mockos.seteuidCalls, [0, 2345])
self.assertEqual(mockos.setegidCalls, [0, 1234])
def assertUnauthorizedLogin(self, d):
"""
Asserts that the L{Deferred} passed in is erred back with an
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{None}
"""
self.assertRaises(
checkers.UnauthorizedLogin, self.assertLoggedIn, d, 'bogus value')
def test_passInCheckers(self):
"""
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
user information.
"""
password = crypt.crypt('secret', 'secret')
userdb = UserDatabase()
userdb.addUser('anybody', password, 1, 2, 'foo', '/bar', '/bin/sh')
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
self.assertLoggedIn(
checker.requestAvatarId(UsernamePassword('anybody', 'secret')),
'anybody')
def test_verifyPassword(self):
"""
If the encrypted password provided by the getpwnam function is valid
(verified by the L{verifyCryptedPassword} function), we callback the
C{requestAvatarId} L{Deferred} with the username.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
def test_failOnKeyError(self):
"""
If the getpwnam function raises a KeyError, the login fails with an
L{UnauthorizedLogin} exception.
"""
def getpwnam(username):
raise KeyError(username)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_failOnBadPassword(self):
"""
If the verifyCryptedPassword function doesn't verify the password, the
login fails with an L{UnauthorizedLogin} exception.
"""
def verifyCryptedPassword(crypted, pw):
return False
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword('username', 'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_loopThroughFunctions(self):
"""
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
function associated with it and returns a L{Deferred} which fires with
the result of the first one which returns a value other than None.
ones do not verify the password.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam1(username):
return [username, 'not the password']
def getpwnam2(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
credential = UsernamePassword('username', 'username')
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
def test_failOnSpecial(self):
"""
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
is not compared against the supplied password. Instead it is skipped.
"""
pwd = UserDatabase()
pwd.addUser('alice', '', 1, 2, '', 'foo', 'bar')
pwd.addUser('bob', 'x', 1, 2, '', 'foo', 'bar')
pwd.addUser('carol', '*', 1, 2, '', 'foo', 'bar')
self.patch(checkers, 'pwd', pwd)
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
cred = UsernamePassword('alice', '')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword('bob', 'x')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword('carol', '*')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
class AuthorizedKeyFileReaderTests(TestCase):
"""
Tests for L{checkers.readAuthorizedKeyFile}
"""
skip = dependencySkip
def test_ignoresComments(self):
"""
L{checkers.readAuthorizedKeyFile} does not attempt to turn comments
into keys
"""
fileobj = StringIO(u'# this comment is ignored\n'
u'this is not\n'
u'# this is again\n'
u'and this is not')
result = checkers.readAuthorizedKeyFile(fileobj, lambda x: x)
self.assertEqual(['this is not', 'and this is not'], list(result))
def test_ignoresLeadingWhitespaceAndEmptyLines(self):
"""
L{checkers.readAuthorizedKeyFile} ignores leading whitespace in
lines, as well as empty lines
"""
fileobj = StringIO(u"""
# ignore
not ignored
""")
result = checkers.readAuthorizedKeyFile(fileobj, parseKey=lambda x: x)
self.assertEqual(['not ignored'], list(result))
def test_ignoresUnparsableKeys(self):
"""
L{checkers.readAuthorizedKeyFile} does not raise an exception
when a key fails to parse (raises a
L{twisted.conch.ssh.keys.BadKeyError}), but rather just keeps going
"""
def failOnSome(line):
if line.startswith('f'):
raise keys.BadKeyError('failed to parse')
return line
fileobj = StringIO(u'failed key\ngood key')
result = checkers.readAuthorizedKeyFile(fileobj,
parseKey=failOnSome)
self.assertEqual(['good key'], list(result))
class InMemorySSHKeyDBTests(TestCase):
"""
Tests for L{checkers.InMemorySSHKeyDB}
"""
skip = dependencySkip
def test_implementsInterface(self):
"""
L{checkers.InMemorySSHKeyDB} implements
L{checkers.IAuthorizedKeysDB}
"""
keydb = checkers.InMemorySSHKeyDB({'alice': ['key']})
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an empty iterator is returned
by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({'alice': ['keys']})
self.assertEqual([], list(keydb.getAuthorizedKeys('bob')))
def test_allKeysForAuthorizedUser(self):
"""
If the user is in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an iterator with all the keys
is returned by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({'alice': ['a', 'b']})
self.assertEqual(['a', 'b'], list(keydb.getAuthorizedKeys('alice')))
class UNIXAuthorizedKeysFilesTests(TestCase):
"""
Tests for L{checkers.UNIXAuthorizedKeysFiles}.
"""
skip = dependencySkip
def setUp(self):
mockos = MockOS()
mockos.path = FilePath(self.mktemp())
mockos.path.makedirs()
self.userdb = UserDatabase()
self.userdb.addUser('alice', 'password', 1, 2, 'alice lastname',
mockos.path.path, '/bin/shell')
self.sshDir = mockos.path.child('.ssh')
self.sshDir.makedirs()
authorizedKeys = self.sshDir.child('authorized_keys')
authorizedKeys.setContent('key 1\nkey 2')
self.expectedKeys = ['key 1', 'key 2']
def test_implementsInterface(self):
"""
L{checkers.UNIXAuthorizedKeysFiles} implements
L{checkers.IAuthorizedKeysDB}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb)
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an empty iterator is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual([], list(keydb.getAuthorizedKeys('bob')))
def test_allKeysInAllAuthorizedFilesForAuthorizedUser(self):
"""
If the user is in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an iterator with all the keys in
C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2} is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
self.sshDir.child('authorized_keys2').setContent('key 3')
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys + ['key 3'],
list(keydb.getAuthorizedKeys('alice')))
def test_ignoresNonexistantFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they exist.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys,
list(keydb.getAuthorizedKeys('alice')))
def test_ignoresUnreadableFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they are readable.
"""
self.sshDir.child('authorized_keys2').makedirs()
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys,
list(keydb.getAuthorizedKeys('alice')))
_KeyDB = namedtuple('KeyDB', ['getAuthorizedKeys'])
class _DummyException(Exception):
"""
Fake exception to be used for testing.
"""
pass
class SSHPublicKeyCheckerTests(TestCase):
"""
Tests for L{checkers.SSHPublicKeyChecker}.
"""
skip = dependencySkip
def setUp(self):
self.credentials = SSHPrivateKey(
'alice', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
self.keydb = _KeyDB(lambda _: [
keys.Key.fromString(keydata.publicRSA_openssh)])
self.checker = checkers.SSHPublicKeyChecker(self.keydb)
def test_credentialsWithoutSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that do not have a signature fails with L{ValidPublicKey}.
"""
self.credentials.signature = None
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
ValidPublicKey)
def test_credentialsWithBadKey(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that have a bad key fails with L{keys.BadKeyError}.
"""
self.credentials.blob = ''
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
keys.BadKeyError)
def test_credentialsNoMatchingKey(self):
"""
If L{checkers.IAuthorizedKeysDB.getAuthorizedKeys} returns no keys
that match the credentials,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
self.credentials.blob = keydata.publicDSA_openssh
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
def test_credentialsInvalidSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that are incorrectly signed fails with
L{UnauthorizedLogin}.
"""
self.credentials.signature = (
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
def test_failureVerifyingKey(self):
"""
If L{keys.Key.verify} raises an exception,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
def fail(*args, **kwargs):
raise _DummyException()
self.patch(keys.Key, 'verify', fail)
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
self.flushLoggedErrors(_DummyException)
def test_usernameReturnedOnSuccess(self):
"""
L{checker.SSHPublicKeyChecker.requestAvatarId}, if successful,
callbacks with the username.
"""
d = self.checker.requestAvatarId(self.credentials)
self.assertEqual('alice', self.successResultOf(d))

View File

@ -0,0 +1,339 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.scripts.ckeygen}.
"""
import getpass
import sys
from StringIO import StringIO
from twisted.python.reflect import requireModule
if requireModule('Crypto') and requireModule('pyasn1'):
from twisted.conch.ssh.keys import Key, BadKeyError
from twisted.conch.scripts.ckeygen import (
changePassPhrase, displayPublicKey, printFingerprint, _saveKey)
else:
skip = "PyCrypto and pyasn1 required for twisted.conch.scripts.ckeygen."
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.conch.test.keydata import (
publicRSA_openssh, privateRSA_openssh, privateRSA_openssh_encrypted)
def makeGetpass(*passphrases):
"""
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
time called. Use case is to provide an old, then new passphrase(s) as if
requested interactively.
@param passphrases: The list of passphrases returned, one per each call.
"""
passphrases = iter(passphrases)
def fakeGetpass(_):
return passphrases.next()
return fakeGetpass
class KeyGenTests(TestCase):
"""
Tests for various functions used to implement the I{ckeygen} script.
"""
def setUp(self):
"""
Patch C{sys.stdout} with a L{StringIO} instance to tests can make
assertions about what's printed.
"""
self.stdout = StringIO()
self.patch(sys, 'stdout', self.stdout)
def test_printFingerprint(self):
"""
L{printFingerprint} writes a line to standard out giving the number of
bits of the key, its fingerprint, and the basename of the file from it
was read.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({'filename': filename})
self.assertEqual(
self.stdout.getvalue(),
'768 3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af temp\n')
def test_saveKey(self):
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key.keyObject,
{'filename': filename, 'pass': 'passphrase'})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint is:\n"
"3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af\n" % (
filename,
filename))
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, 'passphrase'),
key)
self.assertEqual(
Key.fromString(base.child('id_rsa.pub').getContent()),
key.public())
def test_saveKeyEmptyPassphrase(self):
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(
key.keyObject,
{'filename': filename, 'no-passphrase': True})
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, b''),
key)
def test_displayPublicKey(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh)
displayPublicKey({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyEncrypted(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key using the given passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
displayPublicKey({'filename': filename, 'pass': 'encrypted'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyEncryptedPassphrasePrompt(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key, asking for the passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.patch(getpass, 'getpass', lambda x: 'encrypted')
displayPublicKey({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
pubKey.toString('openssh'))
def test_displayPublicKeyWrongPassphrase(self):
"""
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
an encrypted key with the wrong password.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.assertRaises(
BadKeyError, displayPublicKey,
{'filename': filename, 'pass': 'wrong'})
def test_changePassphrase(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key interactively.
"""
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
self.patch(getpass, 'getpass', oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithOld(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key, providing the old passphrase and prompting for new one.
"""
newConfirm = makeGetpass('newpass', 'newpass')
self.patch(getpass, 'getpass', newConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename, 'pass': 'encrypted'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithBoth(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a private
key by providing both old and new passphrases without prompting.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase(
{'filename': filename, 'pass': 'encrypted',
'newpass': 'newencrypt'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWrongPassphrase(self):
"""
L{changePassPhrase} exits if passed an invalid old passphrase when
trying to change the passphrase of a private key.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'pass': 'wrong'})
self.assertEqual('Could not change passphrase: old passphrase error',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseEmptyGetPass(self):
"""
L{changePassPhrase} exits if no passphrase is specified for the
C{getpass} call and the key is encrypted.
"""
self.patch(getpass, 'getpass', makeGetpass(''))
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
self.assertEqual(
'Could not change passphrase: Passphrase must be provided '
'for an encrypted key',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseBadKey(self):
"""
L{changePassPhrase} exits if the file specified points to an invalid
key.
"""
filename = self.mktemp()
FilePath(filename).setContent('foobar')
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
self.assertEqual(
"Could not change passphrase: cannot guess the type of 'foobar'",
str(error))
self.assertEqual('foobar', FilePath(filename).getContent())
def test_changePassphraseCreateError(self):
"""
L{changePassPhrase} doesn't modify the key file if an unexpected error
happens when trying to create the key with the new passphrase.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
raise RuntimeError('oops')
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename,
'newpass': 'newencrypt'})
self.assertEqual(
'Could not change passphrase: oops', str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseEmptyStringError(self):
"""
L{changePassPhrase} doesn't modify the key file if C{toString} returns
an empty string.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
return ''
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'newencrypt'})
self.assertEqual(
"Could not change passphrase: "
"cannot guess the type of ''", str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphrasePublicKey(self):
"""
L{changePassPhrase} exits when trying to change the passphrase on a
public key, and doesn't change the file.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'pass'})
self.assertEqual(
'Could not change passphrase: key not encrypted', str(error))
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())

View File

@ -0,0 +1,577 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, sys, socket
from itertools import count
from zope.interface import implementer
from twisted.cred import portal
from twisted.internet import reactor, defer, protocol
from twisted.internet.error import ProcessExitedAlready
from twisted.internet.task import LoopingCall
from twisted.python import log, runtime
from twisted.trial import unittest
from twisted.conch.error import ConchError
from twisted.conch.avatar import ConchUser
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
try:
from twisted.conch.scripts.conch import SSHSession as StdioInteractingSession
except ImportError, e:
StdioInteractingSession = None
_reason = str(e)
del e
from twisted.conch.test.test_ssh import ConchTestRealm
from twisted.python.procutils import which
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
from twisted.conch.test.test_ssh import Crypto, pyasn1
try:
from twisted.conch.test.test_ssh import ConchTestServerFactory, \
conchTestPublicKeyChecker
except ImportError:
pass
class FakeStdio(object):
"""
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
has been called.
"""
writeConnLost = False
def loseWriteConnection(self):
"""
Record the call to loseWriteConnection.
"""
self.writeConnLost = True
class StdioInteractingSessionTests(unittest.TestCase):
"""
Tests for L{twisted.conch.scripts.conch.SSHSession}.
"""
if StdioInteractingSession is None:
skip = _reason
def test_eofReceived(self):
"""
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
write half of its stdio connection.
"""
stdio = FakeStdio()
channel = StdioInteractingSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class Echo(protocol.Protocol):
def connectionMade(self):
log.msg('ECHO CONNECTION MADE')
def connectionLost(self, reason):
log.msg('ECHO CONNECTION DONE')
def dataReceived(self, data):
self.transport.write(data)
if '\n' in data:
self.transport.loseConnection()
class EchoFactory(protocol.Factory):
protocol = Echo
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
"""
Test protocol for launching an OpenSSH client process.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
buf = ''
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def outReceived(self, data):
self.buf += data
def processEnded(self, reason):
"""
Called when the process has ended.
@param reason: a Failure giving the reason for the process' end.
"""
if reason.value.exitCode != 0:
self._getDeferred().errback(
ConchError("exit code was not 0: %s" %
reason.value.exitCode))
else:
buf = self.buf.replace('\r\n', '\n')
self._getDeferred().callback(buf)
class ConchTestForwardingProcess(protocol.ProcessProtocol):
"""
Manages a third-party process which launches a server.
Uses L{ConchTestForwardingPort} to connect to the third-party server.
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
a Deferred with the data received by the L{ConchTestForwardingPort}.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
def __init__(self, port, data):
"""
@type port: C{int}
@param port: The port on which the third-party server is listening.
(it is assumed that the server is running on localhost).
@type data: C{str}
@param data: This is sent to the third-party server. Must end with '\n'
in order to trigger a disconnect.
"""
self.port = port
self.buffer = None
self.data = data
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def connectionMade(self):
self._connect()
def _connect(self):
"""
Connect to the server, which is often a third-party process.
Tries to reconnect if it fails because we have no way of determining
exactly when the port becomes available for listening -- we can only
know when the process starts.
"""
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self,
self.data)
d = cc.connectTCP('127.0.0.1', self.port)
d.addErrback(self._ebConnect)
return d
def _ebConnect(self, f):
reactor.callLater(.1, self._connect)
def forwardingPortDisconnected(self, buffer):
"""
The network connection has died; save the buffer of output
from the network and attempt to quit the process gracefully,
and then (after the reactor has spun) send it a KILL signal.
"""
self.buffer = buffer
self.transport.write('\x03')
self.transport.loseConnection()
reactor.callLater(0, self._reallyDie)
def _reallyDie(self):
try:
self.transport.signalProcess('KILL')
except ProcessExitedAlready:
pass
def processEnded(self, reason):
"""
Fire the Deferred at self.deferred with the data collected
from the L{ConchTestForwardingPort} connection, if any.
"""
self._getDeferred().callback(self.buffer)
class ConchTestForwardingPort(protocol.Protocol):
"""
Connects to server launched by a third-party process (managed by
L{ConchTestForwardingProcess}) sends data, then reports whatever it
received back to the L{ConchTestForwardingProcess} once the connection
is ended.
"""
def __init__(self, protocol, data):
"""
@type protocol: L{ConchTestForwardingProcess}
@param protocol: The L{ProcessProtocol} which made this connection.
@type data: str
@param data: The data to be sent to the third-party server.
"""
self.protocol = protocol
self.data = data
def connectionMade(self):
self.buffer = ''
self.transport.write(self.data)
def dataReceived(self, data):
self.buffer += data
def connectionLost(self, reason):
self.protocol.forwardingPortDisconnected(self.buffer)
def _makeArgs(args, mod="conch"):
start = [sys.executable, '-c'
"""
### Twisted Preamble
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.basename(path).startswith('Twisted'):
sys.path.insert(0, path)
break
path = os.path.dirname(path)
from twisted.conch.scripts.%s import run
run()""" % mod]
return start + list(args)
class ConchServerSetupMixin:
if not Crypto:
skip = "can't run w/o PyCrypto"
if not pyasn1:
skip = "Cannot run without PyASN1"
realmFactory = staticmethod(lambda: ConchTestRealm('testuser'))
def _createFiles(self):
for f in ['rsa_test','rsa_test.pub','dsa_test','dsa_test.pub',
'kh_test']:
if os.path.exists(f):
os.remove(f)
open('rsa_test','w').write(privateRSA_openssh)
open('rsa_test.pub','w').write(publicRSA_openssh)
open('dsa_test.pub','w').write(publicDSA_openssh)
open('dsa_test','w').write(privateDSA_openssh)
os.chmod('dsa_test', 33152)
os.chmod('rsa_test', 33152)
open('kh_test','w').write('127.0.0.1 '+publicRSA_openssh)
def _getFreePort(self):
s = socket.socket()
s.bind(('', 0))
port = s.getsockname()[1]
s.close()
return port
def _makeConchFactory(self):
"""
Make a L{ConchTestServerFactory}, which allows us to start a
L{ConchTestServer} -- i.e. an actually listening conch.
"""
realm = self.realmFactory()
p = portal.Portal(realm)
p.registerChecker(conchTestPublicKeyChecker())
factory = ConchTestServerFactory()
factory.portal = p
return factory
def setUp(self):
self._createFiles()
self.conchFactory = self._makeConchFactory()
self.conchFactory.expectedLoseConnection = 1
self.conchServer = reactor.listenTCP(0, self.conchFactory,
interface="127.0.0.1")
self.echoServer = reactor.listenTCP(0, EchoFactory())
self.echoPort = self.echoServer.getHost().port
self.echoServerV6 = reactor.listenTCP(0, EchoFactory(), interface="::1")
self.echoPortV6 = self.echoServerV6.getHost().port
def tearDown(self):
try:
self.conchFactory.proto.done = 1
except AttributeError:
pass
else:
self.conchFactory.proto.transport.loseConnection()
return defer.gatherResults([
defer.maybeDeferred(self.conchServer.stopListening),
defer.maybeDeferred(self.echoServer.stopListening),
defer.maybeDeferred(self.echoServerV6.stopListening)])
class ForwardingMixin(ConchServerSetupMixin):
"""
Template class for tests of the Conch server's ability to forward arbitrary
protocols over SSH.
These tests are integration tests, not unit tests. They launch a Conch
server, a custom TCP server (just an L{EchoProtocol}) and then call
L{execute}.
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
cause an SSH client to connect to the Conch server, asking it to forward
data to the custom TCP server.
"""
def test_exec(self):
"""
Test that we can use whatever client to send the command "echo goodbye"
to the Conch server. Make sure we receive "goodbye" back from the
server.
"""
d = self.execute('echo goodbye', ConchTestOpenSSHProcess())
return d.addCallback(self.assertEqual, 'goodbye\n')
def test_localToRemoteForwarding(self):
"""
Test that we can use whatever client to forward a local port to a
specified port on the server.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, 'test\n')
d = self.execute('', process,
sshArgs='-N -L%i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, 'test\n')
return d
def test_remoteToLocalForwarding(self):
"""
Test that we can use whatever client to forward a port from the server
to a port locally.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, 'test\n')
d = self.execute('', process,
sshArgs='-N -R %i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, 'test\n')
return d
# Conventionally there is a separate adapter object which provides ISession for
# the user, but making the user provide ISession directly works too. This isn't
# a full implementation of ISession though, just enough to make these tests
# pass.
@implementer(ISession)
class RekeyAvatar(ConchUser):
"""
This avatar implements a shell which sends 60 numbered lines to whatever
connects to it, then closes the session with a 0 exit status.
60 lines is selected as being enough to send more than 2kB of traffic, the
amount the client is configured to initiate a rekey after.
"""
def __init__(self):
ConchUser.__init__(self)
self.channelLookup['session'] = SSHSession
def openShell(self, transport):
"""
Write 60 lines of data to the transport, then exit.
"""
proto = protocol.Protocol()
proto.makeConnection(transport)
transport.makeConnection(wrapProtocol(proto))
# Send enough bytes to the connection so that a rekey is triggered in
# the client.
def write(counter):
i = counter()
if i == 60:
call.stop()
transport.session.conn.sendRequest(
transport.session, 'exit-status', '\x00\x00\x00\x00')
transport.loseConnection()
else:
transport.write("line #%02d\n" % (i,))
# The timing for this loop is an educated guess (and/or the result of
# experimentation) to exercise the case where a packet is generated
# mid-rekey. Since the other side of the connection is (so far) the
# OpenSSH command line client, there's no easy way to determine when the
# rekey has been initiated. If there were, then generating a packet
# immediately at that time would be a better way to test the
# functionality being tested here.
call = LoopingCall(write, count().next)
call.start(0.01)
def closed(self):
"""
Ignore the close of the session.
"""
class RekeyRealm:
"""
This realm gives out new L{RekeyAvatar} instances for any avatar request.
"""
def requestAvatar(self, avatarID, mind, *interfaces):
return interfaces[0], RekeyAvatar(), lambda: None
class RekeyTestsMixin(ConchServerSetupMixin):
"""
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
of rekeying messages.
"""
realmFactory = RekeyRealm
def test_clientRekey(self):
"""
After a client-initiated rekey is completed, application data continues
to be passed over the SSH connection.
"""
process = ConchTestOpenSSHProcess()
d = self.execute("", process, '-o RekeyLimit=2K')
def finished(result):
self.assertEqual(
result,
'\n'.join(['line #%02d' % (i,) for i in range(60)]) + '\n')
d.addCallback(finished)
return d
class OpenSSHClientMixin:
if not which('ssh'):
skip = "no ssh command-line client available"
def execute(self, remoteCommand, process, sshArgs=''):
"""
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
running the 'ssh' command line tool.
@type remoteCommand: str
@param remoteCommand: The command (with arguments) to run on the
remote end.
@type process: L{ConchTestOpenSSHProcess}
@type sshArgs: str
@param sshArgs: Arguments to pass to the 'ssh' process.
@return: L{defer.Deferred}
"""
process.deferred = defer.Deferred()
cmdline = ('ssh -2 -l testuser -p %i '
'-oUserKnownHostsFile=kh_test '
'-oPasswordAuthentication=no '
# Always use the RSA key, since that's the one in kh_test.
'-oHostKeyAlgorithms=ssh-rsa '
'-a '
'-i dsa_test ') + sshArgs + \
' 127.0.0.1 ' + remoteCommand
port = self.conchServer.getHost().port
cmds = (cmdline % port).split()
reactor.spawnProcess(process, "ssh", cmds)
return process.deferred
class OpenSSHClientForwardingTests(ForwardingMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Connection forwarding tests run against the OpenSSL command line client.
"""
def test_localToRemoteForwardingV6(self):
"""
Forwarding of arbitrary IPv6 TCP connections via SSH.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, 'test\n')
d = self.execute('', process,
sshArgs='-N -L%i:[::1]:%i'
% (localPort, self.echoPortV6))
d.addCallback(self.assertEqual, 'test\n')
return d
class OpenSSHClientRekeyTests(RekeyTestsMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Rekeying tests run against the OpenSSL command line client.
"""
class CmdLineClientTests(ForwardingMixin, unittest.TestCase):
"""
Connection forwarding tests run against the Conch command line client.
"""
if runtime.platformType == 'win32':
skip = "can't run cmdline client on win32"
def execute(self, remoteCommand, process, sshArgs=''):
"""
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
command line tool, not 'ssh'.
"""
process.deferred = defer.Deferred()
port = self.conchServer.getHost().port
cmd = ('-p %i -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'--host-key-algorithms ssh-rsa '
'-a '
'-i dsa_test '
'-v ') % port + sshArgs + \
' 127.0.0.1 ' + remoteCommand
cmds = _makeArgs(cmd.split())
log.msg(str(cmds))
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
reactor.spawnProcess(process, sys.executable, cmds, env=env)
return process.deferred

View File

@ -0,0 +1,730 @@
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
import struct
from twisted.conch import error
from twisted.conch.ssh import channel, common, connection
from twisted.trial import unittest
from twisted.conch.test import test_userauth
class TestChannel(channel.SSHChannel):
"""
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
@ivar gotOpen: True if channelOpen has been called.
@type gotOpen: C{bool}
@ivar specificData: the specific channel open data passed to channelOpen.
@type specificData: C{str}
@ivar openFailureReason: the reason passed to openFailed.
@type openFailed: C{error.ConchError}
@ivar inBuffer: a C{list} of strings received by the channel.
@type inBuffer: C{list}
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
the channel.
@type extBuffer: C{list}
@ivar numberRequests: the number of requests that have been made to this
channel.
@type numberRequests: C{int}
@ivar gotEOF: True if the other side sent EOF.
@type gotEOF: C{bool}
@ivar gotOneClose: True if the other side closed the connection.
@type gotOneClose: C{bool}
@ivar gotClosed: True if the channel is closed.
@type gotClosed: C{bool}
"""
name = "TestChannel"
gotOpen = False
def logPrefix(self):
return "TestChannel %i" % self.id
def channelOpen(self, specificData):
"""
The channel is open. Set up the instance variables.
"""
self.gotOpen = True
self.specificData = specificData
self.inBuffer = []
self.extBuffer = []
self.numberRequests = 0
self.gotEOF = False
self.gotOneClose = False
self.gotClosed = False
def openFailed(self, reason):
"""
Opening the channel failed. Store the reason why.
"""
self.openFailureReason = reason
def request_test(self, data):
"""
A test request. Return True if data is 'data'.
@type data: C{str}
"""
self.numberRequests += 1
return data == 'data'
def dataReceived(self, data):
"""
Data was received. Store it in the buffer.
"""
self.inBuffer.append(data)
def extReceived(self, code, data):
"""
Extended data was received. Store it in the buffer.
"""
self.extBuffer.append((code, data))
def eofReceived(self):
"""
EOF was received. Remember it.
"""
self.gotEOF = True
def closeReceived(self):
"""
Close was received. Remember it.
"""
self.gotOneClose = True
def closed(self):
"""
The channel is closed. Rembember it.
"""
self.gotClosed = True
class TestAvatar:
"""
A mocked-up version of twisted.conch.avatar.ConchUser
"""
_ARGS_ERROR_CODE = 123
def lookupChannel(self, channelType, windowSize, maxPacket, data):
"""
The server wants us to return a channel. If the requested channel is
our TestChannel, return it, otherwise return None.
"""
if channelType == TestChannel.name:
return TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data, avatar=self)
elif channelType == "conch-error-args":
# Raise a ConchError with backwards arguments to make sure the
# connection fixes it for us. This case should be deprecated and
# deleted eventually, but only after all of Conch gets the argument
# order right.
raise error.ConchError(
self._ARGS_ERROR_CODE, "error args in wrong order")
def gotGlobalRequest(self, requestType, data):
"""
The client has made a global request. If the global request is
'TestGlobal', return True. If the global request is 'TestData',
return True and the request-specific data we received. Otherwise,
return False.
"""
if requestType == 'TestGlobal':
return True
elif requestType == 'TestData':
return True, data
else:
return False
class TestConnection(connection.SSHConnection):
"""
A subclass of SSHConnection for testing.
@ivar channel: the current channel.
@type channel. C{TestChannel}
"""
def logPrefix(self):
return "TestConnection"
def global_TestGlobal(self, data):
"""
The other side made the 'TestGlobal' global request. Return True.
"""
return True
def global_Test_Data(self, data):
"""
The other side made the 'Test-Data' global request. Return True and
the data we received.
"""
return True, data
def channel_TestChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the TestChannel. Create a C{TestChannel}
instance, store it, and return it.
"""
self.channel = TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket, data=data)
return self.channel
def channel_ErrorChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the ErrorChannel. Raise an exception.
"""
raise AssertionError('no such thing')
class ConnectionTests(unittest.TestCase):
if test_userauth.transport is None:
skip = "Cannot run without both PyCrypto and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
self.conn.serviceStarted()
def _openChannel(self, channel):
"""
Open the channel with the default connection.
"""
self.conn.openChannel(channel)
self.transport.packets = self.transport.packets[:-1]
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
channel.id, 255) + '\x00\x02\x00\x00\x00\x00\x80\x00')
def tearDown(self):
self.conn.serviceStopped()
def test_linkAvatar(self):
"""
Test that the connection links itself to the avatar in the
transport.
"""
self.assertIs(self.transport.avatar.conn, self.conn)
def test_serviceStopped(self):
"""
Test that serviceStopped() closes any open channels.
"""
channel1 = TestChannel()
channel2 = TestChannel()
self.conn.openChannel(channel1)
self.conn.openChannel(channel2)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00' * 4)
self.assertTrue(channel1.gotOpen)
self.assertFalse(channel2.gotOpen)
self.conn.serviceStopped()
self.assertTrue(channel1.gotClosed)
def test_GLOBAL_REQUEST(self):
"""
Test that global request packets are dispatched to the global_*
methods and the return values are translated into success or failure
messages.
"""
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, '')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestData') + '\xff' +
'test data')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, 'test data')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestBad') + '\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_FAILURE, '')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\x00')
self.assertEqual(self.transport.packets, [])
def test_REQUEST_SUCCESS(self):
"""
Test that global request success packets cause the Deferred to be
called back.
"""
d = self.conn.sendGlobalRequest('request', 'data', True)
self.conn.ssh_REQUEST_SUCCESS('data')
def check(data):
self.assertEqual(data, 'data')
d.addCallback(check)
d.addErrback(self.fail)
return d
def test_REQUEST_FAILURE(self):
"""
Test that global request failure packets cause the Deferred to be
erred back.
"""
d = self.conn.sendGlobalRequest('request', 'data', True)
self.conn.ssh_REQUEST_FAILURE('data')
def check(f):
self.assertEqual(f.value.data, 'data')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_OPEN(self):
"""
Test that open channel packets cause a channel to be created and
opened or a failure message to be returned.
"""
del self.transport.avatar
self.conn.ssh_CHANNEL_OPEN(common.NS('TestChannel') +
'\x00\x00\x00\x01' * 4)
self.assertTrue(self.conn.channel.gotOpen)
self.assertEqual(self.conn.channel.conn, self.conn)
self.assertEqual(self.conn.channel.data, '\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.specificData, '\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
'\x00\x00\x80\x00')])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS('BadChannel') +
'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
'unknown channel') + common.NS(''))])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS('ErrorChannel') +
'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
'unknown failure') + common.NS(''))])
def _lookupChannelErrorTest(self, code):
"""
Deliver a request for a channel open which will result in an exception
being raised during channel lookup. Assert that an error response is
delivered as a result.
"""
self.transport.avatar._ARGS_ERROR_CODE = code
self.conn.ssh_CHANNEL_OPEN(
common.NS('conch-error-args') + '\x00\x00\x00\x01' * 4)
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(
len(errors), 1, "Expected one error, got: %r" % (errors,))
self.assertEqual(errors[0].value.args, (123, "error args in wrong order"))
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
# The response includes some bytes which identifying the
# associated request, as well as the error code (7b in hex) and
# the error message.
'\x00\x00\x00\x01\x00\x00\x00\x7b' + common.NS(
'error args in wrong order') + common.NS(''))])
def test_lookupChannelError(self):
"""
If a C{lookupChannel} implementation raises L{error.ConchError} with the
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
sent in response to the message.
This is a temporary work-around until L{error.ConchError} is given
better attributes and all of the Conch code starts constructing
instances of it properly. Eventually this functionality should be
deprecated and then removed.
"""
self._lookupChannelErrorTest(123)
def test_lookupChannelErrorLongCode(self):
"""
Like L{test_lookupChannelError}, but for the case where the failure code
is represented as a C{long} instead of a C{int}.
"""
self._lookupChannelErrorTest(123L)
def test_CHANNEL_OPEN_CONFIRMATION(self):
"""
Test that channel open confirmation packets cause the channel to be
notified that it's open.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00'*5)
self.assertEqual(channel.remoteWindowLeft, 0)
self.assertEqual(channel.remoteMaxPacket, 0)
self.assertEqual(channel.specificData, '\x00\x00\x00\x00')
self.assertEqual(self.conn.channelsToRemoteChannel[channel],
0)
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
def test_CHANNEL_OPEN_FAILURE(self):
"""
Test that channel open failure packets cause the channel to be
notified that its opening failed.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_FAILURE('\x00\x00\x00\x00\x00\x00\x00'
'\x01' + common.NS('failure!'))
self.assertEqual(channel.openFailureReason.args, ('failure!', 1))
self.assertEqual(self.conn.channels.get(channel), None)
def test_CHANNEL_WINDOW_ADJUST(self):
"""
Test that channel window adjust messages add bytes to the channel
window.
"""
channel = TestChannel()
self._openChannel(channel)
oldWindowSize = channel.remoteWindowLeft
self.conn.ssh_CHANNEL_WINDOW_ADJUST('\x00\x00\x00\x00\x00\x00\x00'
'\x01')
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
def test_CHANNEL_DATA(self):
"""
Test that channel data messages are passed up to the channel, or
cause the channel to be closed if the data is too large.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS('data'))
self.assertEqual(channel.inBuffer, ['data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x04')])
self.transport.packets = []
longData = 'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS(longData))
self.assertEqual(channel.inBuffer, ['data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = 'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x01' + common.NS(bigData))
self.assertEqual(channel.inBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
def test_CHANNEL_EXTENDED_DATA(self):
"""
Test that channel extended data messages are passed up to the channel,
or cause the channel to be closed if they're too big.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
'\x00' + common.NS('data'))
self.assertEqual(channel.extBuffer, [(0, 'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x04')])
self.transport.packets = []
longData = 'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
'\x00' + common.NS(longData))
self.assertEqual(channel.extBuffer, [(0, 'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = 'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x01\x00\x00\x00'
'\x00' + common.NS(bigData))
self.assertEqual(channel.extBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
def test_CHANNEL_EOF(self):
"""
Test that channel eof messages are passed up to the channel.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_EOF('\x00\x00\x00\x00')
self.assertTrue(channel.gotEOF)
def test_CHANNEL_CLOSE(self):
"""
Test that channel close messages are passed up to the channel. Also,
test that channel.close() is called if both sides are closed when this
message is received.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.conn.ssh_CHANNEL_CLOSE('\x00\x00\x00\x00')
self.assertTrue(channel.gotOneClose)
self.assertTrue(channel.gotClosed)
def test_CHANNEL_REQUEST_success(self):
"""
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS('test')
+ '\x00')
self.assertEqual(channel.numberRequests, 1)
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
'test') + '\xff' + 'data')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_SUCCESS, '\x00\x00\x00\xff')])
d.addCallback(check)
return d
def test_CHANNEL_REQUEST_failure(self):
"""
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
'test') + '\xff')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_FAILURE, '\x00\x00\x00\xff'
)])
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_REQUEST_SUCCESS(self):
"""
Test that channel request success messages cause the Deferred to be
called back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', 'data', True)
self.conn.ssh_CHANNEL_SUCCESS('\x00\x00\x00\x00')
def check(result):
self.assertTrue(result)
return d
def test_CHANNEL_REQUEST_FAILURE(self):
"""
Test that channel request failure messages cause the Deferred to be
erred back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', '', True)
self.conn.ssh_CHANNEL_FAILURE('\x00\x00\x00\x00')
def check(result):
self.assertEqual(result.value.value, 'channel request failed')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_sendGlobalRequest(self):
"""
Test that global request messages are sent in the right format.
"""
d = self.conn.sendGlobalRequest('wantReply', 'data', True)
# must be added to prevent errbacking during teardown
d.addErrback(lambda failure: None)
self.conn.sendGlobalRequest('noReply', '', False)
self.assertEqual(self.transport.packets,
[(connection.MSG_GLOBAL_REQUEST, common.NS('wantReply') +
'\xffdata'),
(connection.MSG_GLOBAL_REQUEST, common.NS('noReply') +
'\x00')])
self.assertEqual(self.conn.deferreds, {'global':[d]})
def test_openChannel(self):
"""
Test that open channel messages are sent in the right format.
"""
channel = TestChannel()
self.conn.openChannel(channel, 'aaaa')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN, common.NS('TestChannel') +
'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
self.assertEqual(channel.id, 0)
self.assertEqual(self.conn.localChannelID, 1)
def test_sendRequest(self):
"""
Test that channel request messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, 'test', 'test', True)
# needed to prevent errbacks during teardown.
d.addErrback(lambda failure: None)
self.conn.sendRequest(channel, 'test2', '', False)
channel.localClosed = True # emulate sending a close message
self.conn.sendRequest(channel, 'test3', '', True)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
common.NS('test') + '\x01test'),
(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
common.NS('test2') + '\x00')])
self.assertEqual(self.conn.deferreds[0], [d])
def test_adjustWindow(self):
"""
Test that channel window adjust messages cause bytes to be added
to the window.
"""
channel = TestChannel(localWindow=5)
self._openChannel(channel)
channel.localWindowLeft = 0
self.conn.adjustWindow(channel, 1)
self.assertEqual(channel.localWindowLeft, 1)
channel.localClosed = True
self.conn.adjustWindow(channel, 2)
self.assertEqual(channel.localWindowLeft, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
'\x00\x00\x00\x01')])
def test_sendData(self):
"""
Test that channel data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendData(channel, 'a')
channel.localClosed = True
self.conn.sendData(channel, 'b')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_DATA, '\x00\x00\x00\xff' +
common.NS('a'))])
def test_sendExtendedData(self):
"""
Test that channel extended data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendExtendedData(channel, 1, 'test')
channel.localClosed = True
self.conn.sendExtendedData(channel, 2, 'test2')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EXTENDED_DATA, '\x00\x00\x00\xff' +
'\x00\x00\x00\x01' + common.NS('test'))])
def test_sendEOF(self):
"""
Test that channel EOF messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
channel.localClosed = True
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
def test_sendClose(self):
"""
Test that channel close messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.assertTrue(channel.localClosed)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
self.conn.sendClose(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
channel2 = TestChannel()
self._openChannel(channel2)
channel2.remoteClosed = True
self.conn.sendClose(channel2)
self.assertTrue(channel2.gotClosed)
def test_getChannelWithAvatar(self):
"""
Test that getChannel dispatches to the avatar when an avatar is
present. Correct functioning without the avatar is verified in
test_CHANNEL_OPEN.
"""
channel = self.conn.getChannel('TestChannel', 50, 30, 'data')
self.assertEqual(channel.data, 'data')
self.assertEqual(channel.remoteWindowLeft, 50)
self.assertEqual(channel.remoteMaxPacket, 30)
self.assertRaises(error.ConchError, self.conn.getChannel,
'BadChannel', 50, 30, 'data')
def test_gotGlobalRequestWithoutAvatar(self):
"""
Test that gotGlobalRequests dispatches to global_* without an avatar.
"""
del self.transport.avatar
self.assertTrue(self.conn.gotGlobalRequest('TestGlobal', 'data'))
self.assertEqual(self.conn.gotGlobalRequest('Test-Data', 'data'),
(True, 'data'))
self.assertFalse(self.conn.gotGlobalRequest('BadGlobal', 'data'))
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
"""
Whenever an SSH channel gets closed any Deferred that was returned by a
sendRequest() on its parent connection must be errbacked.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(
channel, "dummyrequest", "dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.channelClosed(channel)
return d
class CleanConnectionShutdownTests(unittest.TestCase):
"""
Check whether correct cleanup is performed on connection shutdown.
"""
if test_userauth.transport is None:
skip = "Cannot run without both PyCrypto and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
"""
Once the service is stopped any leftover global deferred returned by
a sendGlobalRequest() call must be errbacked.
"""
self.conn.serviceStarted()
d = self.conn.sendGlobalRequest(
"dummyrequest", "dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.serviceStopped()
return d

View File

@ -0,0 +1,169 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.client.default}.
"""
from twisted.python.reflect import requireModule
if requireModule('Crypto.Cipher.DES3') and requireModule('pyasn1'):
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import SSHUserAuthClient
from twisted.conch.client.options import ConchOptions
from twisted.conch.ssh.keys import Key
else:
skip = "PyCrypto and PyASN1 required for twisted.conch.client.default."
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.conch.test import keydata
from twisted.test.proto_helpers import StringTransport
class SSHUserAuthClientTests(TestCase):
"""
Tests for L{SSHUserAuthClient}.
@type rsaPublic: L{Key}
@ivar rsaPublic: A public RSA key.
"""
def setUp(self):
self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
self.tmpdir = FilePath(self.mktemp())
self.tmpdir.makedirs()
self.rsaFile = self.tmpdir.child('id_rsa')
self.rsaFile.setContent(keydata.privateRSA_openssh)
self.tmpdir.child('id_rsa.pub').setContent(keydata.publicRSA_openssh)
def test_signDataWithAgent(self):
"""
When connected to an agent, L{SSHUserAuthClient} can use it to
request signatures of particular data with a particular L{Key}.
"""
client = SSHUserAuthClient("user", ConchOptions(), None)
agent = SSHAgentClient()
transport = StringTransport()
agent.makeConnection(transport)
client.keyAgent = agent
cleartext = "Sign here"
client.signData(self.rsaPublic, cleartext)
self.assertEqual(
transport.value(),
"\x00\x00\x00\x8b\r\x00\x00\x00u" + self.rsaPublic.blob() +
"\x00\x00\x00\t" + cleartext +
"\x00\x00\x00\x00")
def test_agentGetPublicKey(self):
"""
L{SSHUserAuthClient} looks up public keys from the agent using the
L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
L{Key} object with one of the public keys in the agent. If no more
keys are present, it returns C{None}.
"""
agent = SSHAgentClient()
agent.blobs = [self.rsaPublic.blob()]
key = agent.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
self.assertEqual(agent.getPublicKey(), None)
def test_getPublicKeyFromFile(self):
"""
L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
the first file described by its options' C{identitys} list, and return
the corresponding public L{Key} object.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyAgentFallback(self):
"""
If an agent is present, but doesn't return a key,
L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
"""
options = ConchOptions()
options.identitys = [self.rsaFile.path]
agent = SSHAgentClient()
client = SSHUserAuthClient("user", options, None)
client.keyAgent = agent
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, self.rsaPublic)
def test_getPublicKeyBadKeyError(self):
"""
If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
calling itself recursively.
"""
options = ConchOptions()
self.tmpdir.child('id_dsa.pub').setContent(keydata.publicDSA_openssh)
dsaFile = self.tmpdir.child('id_dsa')
dsaFile.setContent(keydata.privateDSA_openssh)
options.identitys = [self.rsaFile.path, dsaFile.path]
self.tmpdir.child('id_rsa.pub').setContent('not a key!')
client = SSHUserAuthClient("user", options, None)
key = client.getPublicKey()
self.assertEqual(key.isPublic(), True)
self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
def test_getPrivateKey(self):
"""
L{SSHUserAuthClient.getPrivateKey} will load a private key from the
last used file populated by L{SSHUserAuthClient.getPublicKey}, and
return a L{Deferred} which fires with the corresponding private L{Key}.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
# Populate the list of used files
client.getPublicKey()
def _cbGetPrivateKey(key):
self.assertEqual(key.isPublic(), False)
self.assertEqual(key, rsaPrivate)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
def test_getPrivateKeyPassphrase(self):
"""
L{SSHUserAuthClient} can get a private key from a file, and return a
Deferred called back with a private L{Key} object, even if the key is
encrypted.
"""
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
passphrase = 'this is the passphrase'
self.rsaFile.setContent(rsaPrivate.toString('openssh', passphrase))
options = ConchOptions()
options.identitys = [self.rsaFile.path]
client = SSHUserAuthClient("user", options, None)
# Populate the list of used files
client.getPublicKey()
def _getPassword(prompt):
self.assertEqual(prompt,
"Enter passphrase for key '%s': " % (
self.rsaFile.path,))
return passphrase
def _cbGetPrivateKey(key):
self.assertEqual(key.isPublic(), False)
self.assertEqual(key, rsaPrivate)
self.patch(client, '_getPassword', _getPassword)
return client.getPrivateKey().addCallback(_cbGetPrivateKey)

View File

@ -0,0 +1,770 @@
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE file for details.
"""
Tests for L{twisted.conch.ssh.filetransfer}.
"""
import os
import re
import struct
from twisted.trial import unittest
try:
from twisted.conch import unix
unix # shut up pyflakes
except ImportError:
unix = None
from twisted.conch import avatar
from twisted.conch.ssh import common, connection, filetransfer, session
from twisted.internet import defer
from twisted.protocols import loopback
from twisted.python import components
class TestAvatar(avatar.ConchUser):
def __init__(self):
avatar.ConchUser.__init__(self)
self.channelLookup['session'] = session.SSHSession
self.subsystemLookup['sftp'] = filetransfer.FileTransferServer
def _runAsUser(self, f, *args, **kw):
try:
f = iter(f)
except TypeError:
f = [(f, args, kw)]
for i in f:
func = i[0]
args = len(i)>1 and i[1] or ()
kw = len(i)>2 and i[2] or {}
r = func(*args, **kw)
return r
class FileTransferTestAvatar(TestAvatar):
def __init__(self, homeDir):
TestAvatar.__init__(self)
self.homeDir = homeDir
def getHomeDir(self):
return os.path.join(os.getcwd(), self.homeDir)
class ConchSessionForTestAvatar:
def __init__(self, avatar):
self.avatar = avatar
if unix:
if not hasattr(unix, 'SFTPServerForUnixConchUser'):
# unix should either be a fully working module, or None. I'm not sure
# how this happens, but on win32 it does. Try to cope. --spiv.
import warnings
warnings.warn(("twisted.conch.unix imported %r, "
"but doesn't define SFTPServerForUnixConchUser'")
% (unix,))
unix = None
else:
class FileTransferForTestAvatar(unix.SFTPServerForUnixConchUser):
def gotVersion(self, version, otherExt):
return {'conchTest' : 'ext data'}
def extendedRequest(self, extName, extData):
if extName == 'testExtendedRequest':
return 'bar'
raise NotImplementedError
components.registerAdapter(FileTransferForTestAvatar,
TestAvatar,
filetransfer.ISFTPServer)
class SFTPTestBase(unittest.TestCase):
def setUp(self):
self.testDir = self.mktemp()
# Give the testDir another level so we can safely "cd .." from it in
# tests.
self.testDir = os.path.join(self.testDir, 'extra')
os.makedirs(os.path.join(self.testDir, 'testDirectory'))
f = file(os.path.join(self.testDir, 'testfile1'),'w')
f.write('a'*10+'b'*10)
f.write(file('/dev/urandom').read(1024*64)) # random data
os.chmod(os.path.join(self.testDir, 'testfile1'), 0644)
file(os.path.join(self.testDir, 'testRemoveFile'), 'w').write('a')
file(os.path.join(self.testDir, 'testRenameFile'), 'w').write('a')
file(os.path.join(self.testDir, '.testHiddenFile'), 'w').write('a')
class OurServerOurClientTests(SFTPTestBase):
if not unix:
skip = "can't run on non-posix computers"
def setUp(self):
SFTPTestBase.setUp(self)
self.avatar = FileTransferTestAvatar(self.testDir)
self.server = filetransfer.FileTransferServer(avatar=self.avatar)
clientTransport = loopback.LoopbackRelay(self.server)
self.client = filetransfer.FileTransferClient()
self._serverVersion = None
self._extData = None
def _(serverVersion, extData):
self._serverVersion = serverVersion
self._extData = extData
self.client.gotServerVersion = _
serverTransport = loopback.LoopbackRelay(self.client)
self.client.makeConnection(clientTransport)
self.server.makeConnection(serverTransport)
self.clientTransport = clientTransport
self.serverTransport = serverTransport
self._emptyBuffers()
def _emptyBuffers(self):
while self.serverTransport.buffer or self.clientTransport.buffer:
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def tearDown(self):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
def testServerVersion(self):
self.assertEqual(self._serverVersion, 3)
self.assertEqual(self._extData, {'conchTest' : 'ext data'})
def test_interface_implementation(self):
"""
It implements the ISFTPServer interface.
"""
self.assertTrue(
filetransfer.ISFTPServer.providedBy(self.server.client),
"ISFTPServer not provided by %r" % (self.server.client,))
def test_openedFileClosedWithConnection(self):
"""
A file opened with C{openFile} is close when the connection is lost.
"""
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
oldClose = os.close
closed = []
def close(fd):
closed.append(fd)
oldClose(fd)
self.patch(os, "close", close)
def _fileOpened(openFile):
fd = self.server.openFiles[openFile.handle[4:]].fd
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openFiles, {})
self.assertIn(fd, closed)
d.addCallback(_fileOpened)
return d
def test_openedDirectoryClosedWithConnection(self):
"""
A directory opened with C{openDirectory} is close when the connection
is lost.
"""
d = self.client.openDirectory('')
self._emptyBuffers()
def _getFiles(openDir):
self.serverTransport.loseConnection()
self.clientTransport.loseConnection()
self.serverTransport.clearBuffer()
self.clientTransport.clearBuffer()
self.assertEqual(self.server.openDirs, {})
d.addCallback(_getFiles)
return d
def testOpenFileIO(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _fileOpened(openFile):
self.assertEqual(openFile, filetransfer.ISFTPFile(openFile))
d = _readChunk(openFile)
d.addCallback(_writeChunk, openFile)
return d
def _readChunk(openFile):
d = openFile.readChunk(0, 20)
self._emptyBuffers()
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10)
return d
def _writeChunk(_, openFile):
d = openFile.writeChunk(20, 'c'*10)
self._emptyBuffers()
d.addCallback(_readChunk2, openFile)
return d
def _readChunk2(_, openFile):
d = openFile.readChunk(0, 30)
self._emptyBuffers()
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10 + 'c'*10)
return d
d.addCallback(_fileOpened)
return d
def testClosedFileGetAttrs(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(_, openFile):
d = openFile.getAttrs()
self._emptyBuffers()
return d
def _err(f):
self.flushLoggedErrors()
return f
def _close(openFile):
d = openFile.close()
self._emptyBuffers()
d.addCallback(_getAttrs, openFile)
d.addErrback(_err)
return self.assertFailure(d, filetransfer.SFTPError)
d.addCallback(_close)
return d
def testOpenFileAttributes(self):
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_getAttrs2)
return d
def _getAttrs2(attrs1):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs1)
return d
return d.addCallback(_getAttrs)
def testOpenFileSetAttrs(self):
# XXX test setAttrs
# Ok, how about this for a start? It caught a bug :) -- spiv.
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {})
self._emptyBuffers()
def _getAttrs(openFile):
d = openFile.getAttrs()
self._emptyBuffers()
d.addCallback(_setAttrs)
return d
def _setAttrs(attrs):
attrs['atime'] = 0
d = self.client.setAttrs('testfile1', attrs)
self._emptyBuffers()
d.addCallback(_getAttrs2)
d.addCallback(self.assertEqual, attrs)
return d
def _getAttrs2(_):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
return d
def test_openFileExtendedAttributes(self):
"""
Check that L{filetransfer.FileTransferClient.openFile} can send
extended attributes, that should be extracted server side. By default,
they are ignored, so we just verify they are correctly parsed.
"""
savedAttributes = {}
oldOpenFile = self.server.client.openFile
def openFile(filename, flags, attrs):
savedAttributes.update(attrs)
return oldOpenFile(filename, flags, attrs)
self.server.client.openFile = openFile
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
filetransfer.FXF_WRITE, {"ext_foo": "bar"})
self._emptyBuffers()
def check(ign):
self.assertEqual(savedAttributes, {"ext_foo": "bar"})
return d.addCallback(check)
def testRemoveFile(self):
d = self.client.getAttrs("testRemoveFile")
self._emptyBuffers()
def _removeFile(ignored):
d = self.client.removeFile("testRemoveFile")
self._emptyBuffers()
return d
d.addCallback(_removeFile)
d.addCallback(_removeFile)
return self.assertFailure(d, filetransfer.SFTPError)
def testRenameFile(self):
d = self.client.getAttrs("testRenameFile")
self._emptyBuffers()
def _rename(attrs):
d = self.client.renameFile("testRenameFile", "testRenamedFile")
self._emptyBuffers()
d.addCallback(_testRenamed, attrs)
return d
def _testRenamed(_, attrs):
d = self.client.getAttrs("testRenamedFile")
self._emptyBuffers()
d.addCallback(self.assertEqual, attrs)
return d.addCallback(_rename)
def testDirectoryBad(self):
d = self.client.getAttrs("testMakeDirectory")
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def testDirectoryCreation(self):
d = self.client.makeDirectory("testMakeDirectory", {})
self._emptyBuffers()
def _getAttrs(_):
d = self.client.getAttrs("testMakeDirectory")
self._emptyBuffers()
return d
# XXX not until version 4/5
# self.assertEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'],
# filetransfer.FILEXFER_TYPE_DIRECTORY)
def _removeDirectory(_):
d = self.client.removeDirectory("testMakeDirectory")
self._emptyBuffers()
return d
d.addCallback(_getAttrs)
d.addCallback(_removeDirectory)
d.addCallback(_getAttrs)
return self.assertFailure(d, filetransfer.SFTPError)
def testOpenDirectory(self):
d = self.client.openDirectory('')
self._emptyBuffers()
files = []
def _getFiles(openDir):
def append(f):
files.append(f)
return openDir
d = defer.maybeDeferred(openDir.next)
self._emptyBuffers()
d.addCallback(append)
d.addCallback(_getFiles)
d.addErrback(_close, openDir)
return d
def _checkFiles(ignored):
fs = list(zip(*files)[0])
fs.sort()
self.assertEqual(fs,
['.testHiddenFile', 'testDirectory',
'testRemoveFile', 'testRenameFile',
'testfile1'])
def _close(_, openDir):
d = openDir.close()
self._emptyBuffers()
return d
d.addCallback(_getFiles)
d.addCallback(_checkFiles)
return d
def testLinkDoesntExist(self):
d = self.client.getAttrs('testLink')
self._emptyBuffers()
return self.assertFailure(d, filetransfer.SFTPError)
def testLinkSharesAttrs(self):
d = self.client.makeLink('testLink', 'testfile1')
self._emptyBuffers()
def _getFirstAttrs(_):
d = self.client.getAttrs('testLink', 1)
self._emptyBuffers()
return d
def _getSecondAttrs(firstAttrs):
d = self.client.getAttrs('testfile1')
self._emptyBuffers()
d.addCallback(self.assertEqual, firstAttrs)
return d
d.addCallback(_getFirstAttrs)
return d.addCallback(_getSecondAttrs)
def testLinkPath(self):
d = self.client.makeLink('testLink', 'testfile1')
self._emptyBuffers()
def _readLink(_):
d = self.client.readLink('testLink')
self._emptyBuffers()
d.addCallback(self.assertEqual,
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
return d
def _realPath(_):
d = self.client.realPath('testLink')
self._emptyBuffers()
d.addCallback(self.assertEqual,
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
return d
d.addCallback(_readLink)
d.addCallback(_realPath)
return d
def testExtendedRequest(self):
d = self.client.extendedRequest('testExtendedRequest', 'foo')
self._emptyBuffers()
d.addCallback(self.assertEqual, 'bar')
d.addCallback(self._cbTestExtendedRequest)
return d
def _cbTestExtendedRequest(self, ignored):
d = self.client.extendedRequest('testBadRequest', '')
self._emptyBuffers()
return self.assertFailure(d, NotImplementedError)
class FakeConn:
def sendClose(self, channel):
pass
class FileTransferCloseTests(unittest.TestCase):
if not unix:
skip = "can't run on non-posix computers"
def setUp(self):
self.avatar = TestAvatar()
def buildServerConnection(self):
# make a server connection
conn = connection.SSHConnection()
# server connections have a 'self.transport.avatar'.
class DummyTransport:
def __init__(self):
self.transport = self
def sendPacket(self, kind, data):
pass
def logPrefix(self):
return 'dummy transport'
conn.transport = DummyTransport()
conn.transport.avatar = self.avatar
return conn
def interceptConnectionLost(self, sftpServer):
self.connectionLostFired = False
origConnectionLost = sftpServer.connectionLost
def connectionLost(reason):
self.connectionLostFired = True
origConnectionLost(reason)
sftpServer.connectionLost = connectionLost
def assertSFTPConnectionLost(self):
self.assertTrue(self.connectionLostFired,
"sftpServer's connectionLost was not called")
def test_sessionClose(self):
"""
Closing a session should notify an SFTP subsystem launched by that
session.
"""
# make a session
testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar)
# start an SFTP subsystem on the session
testSession.request_subsystem(common.NS('sftp'))
sftpServer = testSession.client.transport.proto
# intercept connectionLost so we can check that it's called
self.interceptConnectionLost(sftpServer)
# close session
testSession.closeReceived()
self.assertSFTPConnectionLost()
def test_clientClosesChannelOnConnnection(self):
"""
A client sending CHANNEL_CLOSE should trigger closeReceived on the
associated channel instance.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS('session') + struct.pack('>L', 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS('sftp'))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# intercept closeReceived
self.interceptConnectionLost(sftpServer)
# close the connection
conn.ssh_CHANNEL_CLOSE(struct.pack('>L', 0))
self.assertSFTPConnectionLost()
def test_stopConnectionServiceClosesChannel(self):
"""
Closing an SSH connection should close all sessions within it.
"""
conn = self.buildServerConnection()
# somehow get a session
packet = common.NS('session') + struct.pack('>L', 0) * 3
conn.ssh_CHANNEL_OPEN(packet)
sessionChannel = conn.channels[0]
sessionChannel.request_subsystem(common.NS('sftp'))
sftpServer = sessionChannel.client.transport.proto
self.interceptConnectionLost(sftpServer)
# close the connection
conn.serviceStopped()
self.assertSFTPConnectionLost()
class ConstantsTests(unittest.TestCase):
"""
Tests for the constants used by the SFTP protocol implementation.
@ivar filexferSpecExcerpts: Excerpts from the
draft-ietf-secsh-filexfer-02.txt (draft) specification of the SFTP
protocol. There are more recent drafts of the specification, but this
one describes version 3, which is what conch (and OpenSSH) implements.
"""
filexferSpecExcerpts = [
"""
The following values are defined for packet types.
#define SSH_FXP_INIT 1
#define SSH_FXP_VERSION 2
#define SSH_FXP_OPEN 3
#define SSH_FXP_CLOSE 4
#define SSH_FXP_READ 5
#define SSH_FXP_WRITE 6
#define SSH_FXP_LSTAT 7
#define SSH_FXP_FSTAT 8
#define SSH_FXP_SETSTAT 9
#define SSH_FXP_FSETSTAT 10
#define SSH_FXP_OPENDIR 11
#define SSH_FXP_READDIR 12
#define SSH_FXP_REMOVE 13
#define SSH_FXP_MKDIR 14
#define SSH_FXP_RMDIR 15
#define SSH_FXP_REALPATH 16
#define SSH_FXP_STAT 17
#define SSH_FXP_RENAME 18
#define SSH_FXP_READLINK 19
#define SSH_FXP_SYMLINK 20
#define SSH_FXP_STATUS 101
#define SSH_FXP_HANDLE 102
#define SSH_FXP_DATA 103
#define SSH_FXP_NAME 104
#define SSH_FXP_ATTRS 105
#define SSH_FXP_EXTENDED 200
#define SSH_FXP_EXTENDED_REPLY 201
Additional packet types should only be defined if the protocol
version number (see Section ``Protocol Initialization'') is
incremented, and their use MUST be negotiated using the version
number. However, the SSH_FXP_EXTENDED and SSH_FXP_EXTENDED_REPLY
packets can be used to implement vendor-specific extensions. See
Section ``Vendor-Specific-Extensions'' for more details.
""",
"""
The flags bits are defined to have the following values:
#define SSH_FILEXFER_ATTR_SIZE 0x00000001
#define SSH_FILEXFER_ATTR_UIDGID 0x00000002
#define SSH_FILEXFER_ATTR_PERMISSIONS 0x00000004
#define SSH_FILEXFER_ATTR_ACMODTIME 0x00000008
#define SSH_FILEXFER_ATTR_EXTENDED 0x80000000
""",
"""
The `pflags' field is a bitmask. The following bits have been
defined.
#define SSH_FXF_READ 0x00000001
#define SSH_FXF_WRITE 0x00000002
#define SSH_FXF_APPEND 0x00000004
#define SSH_FXF_CREAT 0x00000008
#define SSH_FXF_TRUNC 0x00000010
#define SSH_FXF_EXCL 0x00000020
""",
"""
Currently, the following values are defined (other values may be
defined by future versions of this protocol):
#define SSH_FX_OK 0
#define SSH_FX_EOF 1
#define SSH_FX_NO_SUCH_FILE 2
#define SSH_FX_PERMISSION_DENIED 3
#define SSH_FX_FAILURE 4
#define SSH_FX_BAD_MESSAGE 5
#define SSH_FX_NO_CONNECTION 6
#define SSH_FX_CONNECTION_LOST 7
#define SSH_FX_OP_UNSUPPORTED 8
"""]
def test_constantsAgainstSpec(self):
"""
The constants used by the SFTP protocol implementation match those
found by searching through the spec.
"""
constants = {}
for excerpt in self.filexferSpecExcerpts:
for line in excerpt.splitlines():
m = re.match('^\s*#define SSH_([A-Z_]+)\s+([0-9x]*)\s*$', line)
if m:
constants[m.group(1)] = long(m.group(2), 0)
self.assertTrue(
len(constants) > 0, "No constants found (the test must be buggy).")
for k, v in constants.items():
self.assertEqual(v, getattr(filetransfer, k))
class RawPacketDataTests(unittest.TestCase):
"""
Tests for L{filetransfer.FileTransferClient} which explicitly craft certain
less common protocol messages to exercise their handling.
"""
def setUp(self):
self.ftc = filetransfer.FileTransferClient()
def test_packetSTATUS(self):
"""
A STATUS packet containing a result code, a message, and a language is
parsed to produce the result of an outstanding request L{Deferred}.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUS)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg') + common.NS('lang')
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUS(self, result):
"""
Assert that the result is a two-tuple containing the message and
language from the STATUS packet.
"""
self.assertEqual(result[0], 'msg')
self.assertEqual(result[1], 'lang')
def test_packetSTATUSShort(self):
"""
A STATUS packet containing only a result code can also be parsed to
produce the result of an outstanding request L{Deferred}. Such packets
are sent by some SFTP implementations, though not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSShort)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK)
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSShort(self, result):
"""
Assert that the result is a two-tuple containing empty strings, since
the STATUS packet had neither a message nor a language.
"""
self.assertEqual(result[0], '')
self.assertEqual(result[1], '')
def test_packetSTATUSWithoutLang(self):
"""
A STATUS packet containing a result code and a message but no language
can also be parsed to produce the result of an outstanding request
L{Deferred}. Such packets are sent by some SFTP implementations, though
not strictly legal.
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
of the SFTP Internet-Draft.
"""
d = defer.Deferred()
d.addCallback(self._cbTestPacketSTATUSWithoutLang)
self.ftc.openRequests[1] = d
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg')
self.ftc.packet_STATUS(data)
return d
def _cbTestPacketSTATUSWithoutLang(self, result):
"""
Assert that the result is a two-tuple containing the message from the
STATUS packet and an empty string, since the language was missing.
"""
self.assertEqual(result[0], 'msg')
self.assertEqual(result[1], '')

View File

@ -0,0 +1,82 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.forwarding}.
"""
from socket import AF_INET6
from twisted.conch.ssh import forwarding
from twisted.internet import defer
from twisted.internet.address import IPv6Address
from twisted.trial import unittest
from twisted.test.proto_helpers import MemoryReactorClock, StringTransport
class TestSSHConnectForwardingChannel(unittest.TestCase):
"""
Unit and integration tests for L{SSHConnectForwardingChannel}.
"""
def patchHostnameEndpointResolver(self, request, response):
"""
Patch L{forwarding.HostnameEndpoint} to respond with a predefined
answer for DNS resolver requests.
@param request: Tupple of requested (hostname, port).
@type request: C{tuppe}.
@param response: Tupple of (family, address) to respond the the
associated C{request}.
@type response: C{tuppe}.
"""
hostname, port = request
family, address = response
riggerResolver = {('fwd.example.org', 1234): (
AF_INET6, None, None, None, ('::1', 1234))}
def riggedResolution(this, host, port):
return defer.succeed([riggerResolver[(host, port)]])
self.patch(
forwarding.HostnameEndpoint, '_nameResolution', riggedResolution)
def makeTCPConnection(self, reactor):
"""
Fake that connection was established for first connectTCP request made
on C{reactor}.
@param reactor: Reactor on which to fake the connection.
@type reactor: A reactor.
"""
factory = reactor.tcpClients[0][2]
connector = reactor.connectors[0]
protocol = factory.buildProtocol(None)
transport = StringTransport(peerAddress=connector.getDestination())
protocol.makeConnection(transport)
def test_channelOpenHostnameRequests(self):
"""
When a hostname is sent as part of forwarding requests, it
is resolved using HostnameEndpoint's resolver.
"""
sut = forwarding.SSHConnectForwardingChannel(
hostport=('fwd.example.org', 1234))
# Patch channel and resolver to not touch the network.
sut._reactor = MemoryReactorClock()
self.patchHostnameEndpointResolver(
request=('fwd.example.org', 1234),
response=(AF_INET6 ,'::1'),
)
sut.channelOpen(None)
self.makeTCPConnection(sut._reactor)
self.successResultOf(sut._channelOpenDeferred)
# Channel is connected using a forwarding client to the resolved
# address of the requested host.
self.assertTrue(isinstance(sut.client, forwarding.SSHForwardingClient))
self.assertEqual(
IPv6Address('TCP', '::1', 1234), sut.client.transport.getPeer())

View File

@ -0,0 +1,614 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.conch.insults import helper
from twisted.conch.insults.insults import G0, G1, G2, G3
from twisted.conch.insults.insults import modes, privateModes
from twisted.conch.insults.insults import (
NORMAL, BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
from twisted.trial import unittest
WIDTH = 80
HEIGHT = 24
class BufferTests(unittest.TestCase):
def setUp(self):
self.term = helper.TerminalBuffer()
self.term.connectionMade()
def testInitialState(self):
self.assertEqual(self.term.width, WIDTH)
self.assertEqual(self.term.height, HEIGHT)
self.assertEqual(str(self.term),
'\n' * (HEIGHT - 1))
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def test_initialPrivateModes(self):
"""
Verify that only DEC Auto Wrap Mode (DECAWM) and DEC Text Cursor Enable
Mode (DECTCEM) are initially in the Set Mode (SM) state.
"""
self.assertEqual(
{privateModes.AUTO_WRAP: True,
privateModes.CURSOR_MODE: True},
self.term.privateModes)
def test_carriageReturn(self):
"""
C{"\r"} moves the cursor to the first column in the current row.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.insertAtCursor("\r")
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
def test_linefeed(self):
"""
C{"\n"} moves the cursor to the next row without changing the column.
"""
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 0))
self.term.insertAtCursor("\n")
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
def test_newline(self):
"""
C{write} transforms C{"\n"} into C{"\r\n"}.
"""
self.term.cursorForward(5)
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
self.term.write("\n")
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
def test_setPrivateModes(self):
"""
Verify that L{helper.TerminalBuffer.setPrivateModes} changes the Set
Mode (SM) state to "set" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.setPrivateModes([privateModes.SCROLL, privateModes.SCREEN])
expected[privateModes.SCROLL] = True
expected[privateModes.SCREEN] = True
self.assertEqual(expected, self.term.privateModes)
def test_resetPrivateModes(self):
"""
Verify that L{helper.TerminalBuffer.resetPrivateModes} changes the Set
Mode (SM) state to "reset" for the private modes it is passed.
"""
expected = self.term.privateModes.copy()
self.term.resetPrivateModes([privateModes.AUTO_WRAP, privateModes.CURSOR_MODE])
del expected[privateModes.AUTO_WRAP]
del expected[privateModes.CURSOR_MODE]
self.assertEqual(expected, self.term.privateModes)
def testCursorDown(self):
self.term.cursorDown(3)
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
self.term.cursorDown()
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testCursorUp(self):
self.term.cursorUp(5)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(20)
self.term.cursorUp(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 19))
self.term.cursorUp(19)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorForward(self):
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (2, 0))
self.term.cursorForward(2)
self.assertEqual(self.term.reportCursorPosition(), (4, 0))
self.term.cursorForward(WIDTH)
self.assertEqual(self.term.reportCursorPosition(), (WIDTH, 0))
def testCursorBackward(self):
self.term.cursorForward(10)
self.term.cursorBackward(2)
self.assertEqual(self.term.reportCursorPosition(), (8, 0))
self.term.cursorBackward(7)
self.assertEqual(self.term.reportCursorPosition(), (1, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorBackward(1)
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
def testCursorPositioning(self):
self.term.cursorPosition(3, 9)
self.assertEqual(self.term.reportCursorPosition(), (3, 9))
def testSimpleWriting(self):
s = "Hello, world."
self.term.write(s)
self.assertEqual(
str(self.term),
s + '\n' +
'\n' * (HEIGHT - 2))
def testOvertype(self):
s = "hello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.resetModes([modes.IRM])
self.term.write("H")
self.assertEqual(
str(self.term),
("H" + s[1:]) + '\n' +
'\n' * (HEIGHT - 2))
def testInsert(self):
s = "ello, world."
self.term.write(s)
self.term.cursorBackward(len(s))
self.term.setModes([modes.IRM])
self.term.write("H")
self.assertEqual(
str(self.term),
("H" + s) + '\n' +
'\n' * (HEIGHT - 2))
def testWritingInTheMiddle(self):
s = "Hello, world."
self.term.cursorDown(5)
self.term.cursorForward(5)
self.term.write(s)
self.assertEqual(
str(self.term),
'\n' * 5 +
(self.term.fill * 5) + s + '\n' +
'\n' * (HEIGHT - 7))
def testWritingWrappedAtEndOfLine(self):
s = "Hello, world."
self.term.cursorForward(WIDTH - 5)
self.term.write(s)
self.assertEqual(
str(self.term),
s[:5].rjust(WIDTH) + '\n' +
s[5:] + '\n' +
'\n' * (HEIGHT - 3))
def testIndex(self):
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorDown(HEIGHT)
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
self.term.index()
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
def testReverseIndex(self):
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
self.term.cursorDown(2)
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
self.term.reverseIndex()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
def test_nextLine(self):
"""
C{nextLine} positions the cursor at the beginning of the row below the
current row.
"""
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
self.term.cursorForward(5)
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
self.term.nextLine()
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
def testSaveCursor(self):
self.term.cursorDown(5)
self.term.cursorForward(7)
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
self.term.saveCursor()
self.term.cursorDown(7)
self.term.cursorBackward(3)
self.assertEqual(self.term.reportCursorPosition(), (4, 12))
self.term.restoreCursor()
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
def testSingleShifts(self):
self.term.singleShift2()
self.term.write('Hi')
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], 'H')
self.assertEqual(ch[1].charset, G2)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], 'i')
self.assertEqual(ch[1].charset, G0)
self.term.singleShift3()
self.term.write('!!')
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], '!')
self.assertEqual(ch[1].charset, G3)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], '!')
self.assertEqual(ch[1].charset, G0)
def testShifting(self):
s1 = "Hello"
s2 = "World"
s3 = "Bye!"
self.term.write("Hello\n")
self.term.shiftOut()
self.term.write("World\n")
self.term.shiftIn()
self.term.write("Bye!\n")
g = G0
h = 0
for s in (s1, s2, s3):
for i in range(len(s)):
ch = self.term.getCharacter(i, h)
self.assertEqual(ch[0], s[i])
self.assertEqual(ch[1].charset, g)
g = g == G0 and G1 or G0
h += 1
def testGraphicRendition(self):
self.term.selectGraphicRendition(BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
self.term.write('W')
self.term.selectGraphicRendition(NORMAL)
self.term.write('X')
self.term.selectGraphicRendition(BLINK)
self.term.write('Y')
self.term.selectGraphicRendition(BOLD)
self.term.write('Z')
ch = self.term.getCharacter(0, 0)
self.assertEqual(ch[0], 'W')
self.assertTrue(ch[1].bold)
self.assertTrue(ch[1].underline)
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].reverseVideo)
ch = self.term.getCharacter(1, 0)
self.assertEqual(ch[0], 'X')
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].blink)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(2, 0)
self.assertEqual(ch[0], 'Y')
self.assertTrue(ch[1].blink)
self.assertFalse(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
ch = self.term.getCharacter(3, 0)
self.assertEqual(ch[0], 'Z')
self.assertTrue(ch[1].blink)
self.assertTrue(ch[1].bold)
self.assertFalse(ch[1].underline)
self.assertFalse(ch[1].reverseVideo)
def testColorAttributes(self):
s1 = "Merry xmas"
s2 = "Just kidding"
self.term.selectGraphicRendition(helper.FOREGROUND + helper.RED,
helper.BACKGROUND + helper.GREEN)
self.term.write(s1 + "\n")
self.term.selectGraphicRendition(NORMAL)
self.term.write(s2 + "\n")
for i in range(len(s1)):
ch = self.term.getCharacter(i, 0)
self.assertEqual(ch[0], s1[i])
self.assertEqual(ch[1].charset, G0)
self.assertEqual(ch[1].bold, False)
self.assertEqual(ch[1].underline, False)
self.assertEqual(ch[1].blink, False)
self.assertEqual(ch[1].reverseVideo, False)
self.assertEqual(ch[1].foreground, helper.RED)
self.assertEqual(ch[1].background, helper.GREEN)
for i in range(len(s2)):
ch = self.term.getCharacter(i, 1)
self.assertEqual(ch[0], s2[i])
self.assertEqual(ch[1].charset, G0)
self.assertEqual(ch[1].bold, False)
self.assertEqual(ch[1].underline, False)
self.assertEqual(ch[1].blink, False)
self.assertEqual(ch[1].reverseVideo, False)
self.assertEqual(ch[1].foreground, helper.WHITE)
self.assertEqual(ch[1].background, helper.BLACK)
def testEraseLine(self):
s1 = 'line 1'
s2 = 'line 2'
s3 = 'line 3'
self.term.write('\n'.join((s1, s2, s3)) + '\n')
self.term.cursorPosition(1, 1)
self.term.eraseLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
'\n' +
s3 + '\n' +
'\n' * (HEIGHT - 4))
def testEraseToLineEnd(self):
s = 'Hello, world.'
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineEnd()
self.assertEqual(
str(self.term),
s[:-5] + '\n' +
'\n' * (HEIGHT - 2))
def testEraseToLineBeginning(self):
s = 'Hello, world.'
self.term.write(s)
self.term.cursorBackward(5)
self.term.eraseToLineBeginning()
self.assertEqual(
str(self.term),
s[-4:].rjust(len(s)) + '\n' +
'\n' * (HEIGHT - 2))
def testEraseDisplay(self):
self.term.write('Hello world\n')
self.term.write('Goodbye world\n')
self.term.eraseDisplay()
self.assertEqual(
str(self.term),
'\n' * (HEIGHT - 1))
def testEraseToDisplayEnd(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2, '')))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayEnd()
self.assertEqual(
str(self.term),
s1 + '\n' +
s2[:5] + '\n' +
'\n' * (HEIGHT - 3))
def testEraseToDisplayBeginning(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2)))
self.term.cursorPosition(5, 1)
self.term.eraseToDisplayBeginning()
self.assertEqual(
str(self.term),
'\n' +
s2[6:].rjust(len(s2)) + '\n' +
'\n' * (HEIGHT - 3))
def testLineInsertion(self):
s1 = "Hello world"
s2 = "Goodbye world"
self.term.write('\n'.join((s1, s2)))
self.term.cursorPosition(7, 1)
self.term.insertLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
'\n' +
s2 + '\n' +
'\n' * (HEIGHT - 4))
def testLineDeletion(self):
s1 = "Hello world"
s2 = "Middle words"
s3 = "Goodbye world"
self.term.write('\n'.join((s1, s2, s3)))
self.term.cursorPosition(9, 1)
self.term.deleteLine()
self.assertEqual(
str(self.term),
s1 + '\n' +
s3 + '\n' +
'\n' * (HEIGHT - 3))
class FakeDelayedCall:
called = False
cancelled = False
def __init__(self, fs, timeout, f, a, kw):
self.fs = fs
self.timeout = timeout
self.f = f
self.a = a
self.kw = kw
def active(self):
return not (self.cancelled or self.called)
def cancel(self):
self.cancelled = True
# self.fs.calls.remove(self)
def call(self):
self.called = True
self.f(*self.a, **self.kw)
class FakeScheduler:
def __init__(self):
self.calls = []
def callLater(self, timeout, f, *a, **kw):
self.calls.append(FakeDelayedCall(self, timeout, f, a, kw))
return self.calls[-1]
class ExpectTests(unittest.TestCase):
def setUp(self):
self.term = helper.ExpectableBuffer()
self.term.connectionMade()
self.fs = FakeScheduler()
def testSimpleString(self):
result = []
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d.addCallback(result.append)
self.term.write("greeting puny earthlings\n")
self.assertFalse(result)
self.term.write("hello world\n")
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
self.assertEqual(len(self.fs.calls), 1)
self.assertFalse(self.fs.calls[0].active())
def testBrokenUpString(self):
result = []
d = self.term.expect("hello world")
d.addCallback(result.append)
self.assertFalse(result)
self.term.write("hello ")
self.assertFalse(result)
self.term.write("worl")
self.assertFalse(result)
self.term.write("d")
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
def testMultiple(self):
result = []
d1 = self.term.expect("hello ")
d1.addCallback(result.append)
d2 = self.term.expect("world")
d2.addCallback(result.append)
self.assertFalse(result)
self.term.write("hello")
self.assertFalse(result)
self.term.write(" ")
self.assertEqual(len(result), 1)
self.term.write("world")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), "hello ")
self.assertEqual(result[1].group(), "world")
def testSynchronous(self):
self.term.write("hello world")
result = []
d = self.term.expect("hello world")
d.addCallback(result.append)
self.assertTrue(result)
self.assertEqual(result[0].group(), "hello world")
def testMultipleSynchronous(self):
self.term.write("goodbye world")
result = []
d1 = self.term.expect("bye")
d1.addCallback(result.append)
d2 = self.term.expect("world")
d2.addCallback(result.append)
self.assertEqual(len(result), 2)
self.assertEqual(result[0].group(), "bye")
self.assertEqual(result[1].group(), "world")
def _cbTestTimeoutFailure(self, res):
self.assertTrue(hasattr(res, 'type'))
self.assertEqual(res.type, helper.ExpectationTimeout)
def testTimeoutFailure(self):
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d.addBoth(self._cbTestTimeoutFailure)
self.fs.calls[0].call()
def testOverlappingTimeout(self):
self.term.write("not zoomtastic")
result = []
d1 = self.term.expect("hello world", timeout=1, scheduler=self.fs)
d1.addBoth(self._cbTestTimeoutFailure)
d2 = self.term.expect("zoom")
d2.addCallback(result.append)
self.fs.calls[0].call()
self.assertEqual(len(result), 1)
self.assertEqual(result[0].group(), "zoom")
class CharacterAttributeTests(unittest.TestCase):
"""
Tests for L{twisted.conch.insults.helper.CharacterAttribute}.
"""
def test_equality(self):
"""
L{CharacterAttribute}s must have matching character attribute values
(bold, blink, underline, etc) with the same values to be considered
equal.
"""
self.assertEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute())
self.assertEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute(charset=G0))
self.assertEqual(
helper.CharacterAttribute(
bold=True, underline=True, blink=False, reverseVideo=True,
foreground=helper.BLUE),
helper.CharacterAttribute(
bold=True, underline=True, blink=False, reverseVideo=True,
foreground=helper.BLUE))
self.assertNotEqual(
helper.CharacterAttribute(),
helper.CharacterAttribute(charset=G1))
self.assertNotEqual(
helper.CharacterAttribute(bold=True),
helper.CharacterAttribute(bold=False))
def test_wantOneDeprecated(self):
"""
L{twisted.conch.insults.helper.CharacterAttribute.wantOne} emits
a deprecation warning when invoked.
"""
# Trigger the deprecation warning.
helper._FormattingState().wantOne(bold=True)
warningsShown = self.flushWarnings([self.test_wantOneDeprecated])
self.assertEqual(len(warningsShown), 1)
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
'twisted.conch.insults.helper.wantOne was deprecated in '
'Twisted 13.1.0')

View File

@ -0,0 +1,497 @@
# -*- test-case-name: twisted.conch.test.test_insults -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.python.reflect import namedAny
from twisted.trial import unittest
from twisted.test.proto_helpers import StringTransport
from twisted.conch.insults.insults import ServerProtocol, ClientProtocol
from twisted.conch.insults.insults import CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL
from twisted.conch.insults.insults import G0, G1
from twisted.conch.insults.insults import modes
def _getattr(mock, name):
return super(Mock, mock).__getattribute__(name)
def occurrences(mock):
return _getattr(mock, 'occurrences')
def methods(mock):
return _getattr(mock, 'methods')
def _append(mock, obj):
occurrences(mock).append(obj)
default = object()
class Mock(object):
callReturnValue = default
def __init__(self, methods=None, callReturnValue=default):
"""
@param methods: Mapping of names to return values
@param callReturnValue: object __call__ should return
"""
self.occurrences = []
if methods is None:
methods = {}
self.methods = methods
if callReturnValue is not default:
self.callReturnValue = callReturnValue
def __call__(self, *a, **kw):
returnValue = _getattr(self, 'callReturnValue')
if returnValue is default:
returnValue = Mock()
# _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
_append(self, ('__call__', returnValue, a, kw))
return returnValue
def __getattribute__(self, name):
methods = _getattr(self, 'methods')
if name in methods:
attrValue = Mock(callReturnValue=methods[name])
else:
attrValue = Mock()
# _getattr(self, 'occurrences').append((name, attrValue))
_append(self, (name, attrValue))
return attrValue
class MockMixin:
def assertCall(self, occurrence, methodName, expectedPositionalArgs=(),
expectedKeywordArgs={}):
attr, mock = occurrence
self.assertEqual(attr, methodName)
self.assertEqual(len(occurrences(mock)), 1)
[(call, result, args, kw)] = occurrences(mock)
self.assertEqual(call, "__call__")
self.assertEqual(args, expectedPositionalArgs)
self.assertEqual(kw, expectedKeywordArgs)
return result
_byteGroupingTestTemplate = """\
def testByte%(groupName)s(self):
transport = StringTransport()
proto = Mock()
parser = self.protocolFactory(lambda: proto)
parser.factory = self
parser.makeConnection(transport)
bytes = self.TEST_BYTES
while bytes:
chunk = bytes[:%(bytesPer)d]
bytes = bytes[%(bytesPer)d:]
parser.dataReceived(chunk)
self.verifyResults(transport, proto, parser)
"""
class ByteGroupingsMixin(MockMixin):
protocolFactory = None
for word, n in [('Pairs', 2), ('Triples', 3), ('Quads', 4), ('Quints', 5), ('Sexes', 6)]:
exec _byteGroupingTestTemplate % {'groupName': word, 'bytesPer': n}
del word, n
def verifyResults(self, transport, proto, parser):
result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
self.assertEqual(occurrences(result), [])
del _byteGroupingTestTemplate
class ServerArrowKeysTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# All the arrow keys once
TEST_BYTES = '\x1b[A\x1b[B\x1b[C\x1b[D'
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for arrow in (parser.UP_ARROW, parser.DOWN_ARROW,
parser.RIGHT_ARROW, parser.LEFT_ARROW):
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (arrow, None))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class PrintableCharactersTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ServerProtocol
# Some letters and digits, first on their own, then capitalized,
# then modified with alt
TEST_BYTES = 'abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3'
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for char in 'abc123ABC!@#':
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, None))
self.assertEqual(occurrences(result), [])
for char in 'abc123':
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT))
self.assertEqual(occurrences(result), [])
occs = occurrences(proto)
self.assertFalse(occs, "%r should have been []" % (occs,))
class ServerFunctionKeysTests(ByteGroupingsMixin, unittest.TestCase):
"""Test for parsing and dispatching function keys (F1 - F12)
"""
protocolFactory = ServerProtocol
byteList = []
for bytes in ('OP', 'OQ', 'OR', 'OS', # F1 - F4
'15~', '17~', '18~', '19~', # F5 - F8
'20~', '21~', '23~', '24~'): # F9 - F12
byteList.append('\x1b[' + bytes)
TEST_BYTES = ''.join(byteList)
del byteList, bytes
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for funcNum in range(1, 13):
funcArg = getattr(parser, 'F%d' % (funcNum,))
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientCursorMovementTests(ByteGroupingsMixin, unittest.TestCase):
protocolFactory = ClientProtocol
d2 = "\x1b[2B"
r4 = "\x1b[4C"
u1 = "\x1b[A"
l2 = "\x1b[2D"
# Move the cursor down two, right four, up one, left two, up one, left two
TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
del d2, r4, u1, l2
def verifyResults(self, transport, proto, parser):
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
for (method, count) in [('Down', 2), ('Forward', 4), ('Up', 1),
('Backward', 2), ('Up', 1), ('Backward', 2)]:
result = self.assertCall(occurrences(proto).pop(0), "cursor" + method, (count,))
self.assertEqual(occurrences(result), [])
self.assertFalse(occurrences(proto))
class ClientControlSequencesTests(unittest.TestCase, MockMixin):
def setUp(self):
self.transport = StringTransport()
self.proto = Mock()
self.parser = ClientProtocol(lambda: self.proto)
self.parser.factory = self
self.parser.makeConnection(self.transport)
result = self.assertCall(occurrences(self.proto).pop(0), "makeConnection", (self.parser,))
self.assertFalse(occurrences(result))
def testSimpleCardinals(self):
self.parser.dataReceived(
''.join([''.join(['\x1b[' + str(n) + ch for n in ('', 2, 20, 200)]) for ch in 'BACD']))
occs = occurrences(self.proto)
for meth in ("Down", "Up", "Forward", "Backward"):
for count in (1, 2, 20, 200):
result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testScrollRegion(self):
self.parser.dataReceived('\x1b[5;22r\x1b[r')
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testHeightAndWidth(self):
self.parser.dataReceived("\x1b#3\x1b#4\x1b#5\x1b#6")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleWidthLine")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "doubleWidthLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCharacterSet(self):
self.parser.dataReceived(
''.join([''.join(['\x1b' + g + n for n in 'AB012']) for g in '()']))
occs = occurrences(self.proto)
for which in (G0, G1):
for charset in (CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL):
result = self.assertCall(occs.pop(0), "selectCharacterSet", (charset, which))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testShifting(self):
self.parser.dataReceived("\x15\x14")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "shiftIn")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "shiftOut")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testSingleShifts(self):
self.parser.dataReceived("\x1bN\x1bO")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "singleShift2")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "singleShift3")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testKeypadMode(self):
self.parser.dataReceived("\x1b=\x1b>")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "applicationKeypadMode")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "numericKeypadMode")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursor(self):
self.parser.dataReceived("\x1b7\x1b8")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "saveCursor")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "restoreCursor")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testReset(self):
self.parser.dataReceived("\x1bc")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reset")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testIndex(self):
self.parser.dataReceived("\x1bD\x1bM\x1bE")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "index")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "reverseIndex")
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "nextLine")
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testModes(self):
self.parser.dataReceived(
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "h")
self.parser.dataReceived(
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "l")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],))
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testErasure(self):
self.parser.dataReceived(
"\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
occs = occurrences(self.proto)
for meth in ("eraseToLineEnd", "eraseToLineBeginning", "eraseLine",
"eraseToDisplayEnd", "eraseToDisplayBeginning",
"eraseDisplay"):
result = self.assertCall(occs.pop(0), meth)
self.assertFalse(occurrences(result))
result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineDeletion(self):
self.parser.dataReceived("\x1b[M\x1b[3M")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testLineInsertion(self):
self.parser.dataReceived("\x1b[L\x1b[3L")
occs = occurrences(self.proto)
for arg in (1, 3):
result = self.assertCall(occs.pop(0), "insertLine", (arg,))
self.assertFalse(occurrences(result))
self.assertFalse(occs)
def testCursorPosition(self):
methods(self.proto)['reportCursorPosition'] = (6, 7)
self.parser.dataReceived("\x1b[6n")
self.assertEqual(self.transport.value(), "\x1b[7;8R")
occs = occurrences(self.proto)
result = self.assertCall(occs.pop(0), "reportCursorPosition")
# This isn't really an interesting assert, since it only tests that
# our mock setup is working right, but I'll include it anyway.
self.assertEqual(result, (6, 7))
def test_applicationDataBytes(self):
"""
Contiguous non-control bytes are passed to a single call to the
C{write} method of the terminal to which the L{ClientProtocol} is
connected.
"""
occs = occurrences(self.proto)
self.parser.dataReceived('a')
self.assertCall(occs.pop(0), "write", ("a",))
self.parser.dataReceived('bc')
self.assertCall(occs.pop(0), "write", ("bc",))
def _applicationDataTest(self, data, calls):
occs = occurrences(self.proto)
self.parser.dataReceived(data)
while calls:
self.assertCall(occs.pop(0), *calls.pop(0))
self.assertFalse(occs, "No other calls should happen: %r" % (occs,))
def test_shiftInAfterApplicationData(self):
"""
Application data bytes followed by a shift-in command are passed to a
call to C{write} before the terminal's C{shiftIn} method is called.
"""
self._applicationDataTest(
'ab\x15', [
("write", ("ab",)),
("shiftIn",)])
def test_shiftOutAfterApplicationData(self):
"""
Application data bytes followed by a shift-out command are passed to a
call to C{write} before the terminal's C{shiftOut} method is called.
"""
self._applicationDataTest(
'ab\x14', [
("write", ("ab",)),
("shiftOut",)])
def test_cursorBackwardAfterApplicationData(self):
"""
Application data bytes followed by a cursor-backward command are passed
to a call to C{write} before the terminal's C{cursorBackward} method is
called.
"""
self._applicationDataTest(
'ab\x08', [
("write", ("ab",)),
("cursorBackward",)])
def test_escapeAfterApplicationData(self):
"""
Application data bytes followed by an escape character are passed to a
call to C{write} before the terminal's handler method for the escape is
called.
"""
# Test a short escape
self._applicationDataTest(
'ab\x1bD', [
("write", ("ab",)),
("index",)])
# And a long escape
self._applicationDataTest(
'ab\x1b[4h', [
("write", ("ab",)),
("setModes", ([4],))])
# There's some other cases too, but they're all handled by the same
# codepaths as above.
class ServerProtocolOutputTests(unittest.TestCase):
"""
Tests for the bytes L{ServerProtocol} writes to its transport when its
methods are called.
"""
def test_nextLine(self):
"""
L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
"""
# Why doesn't it write ESC E? Because ESC E is poorly supported. For
# example, gnome-terminal (many different versions) fails to scroll if
# it receives ESC E and the cursor is already on the last row.
protocol = ServerProtocol()
transport = StringTransport()
protocol.makeConnection(transport)
protocol.nextLine()
self.assertEqual(transport.value(), "\r\n")
class DeprecationsTests(unittest.TestCase):
"""
Tests to ensure deprecation of L{insults.colors} and L{insults.client}
"""
def ensureDeprecated(self, message):
"""
Ensures that the correct deprecation warning was issued.
"""
warnings = self.flushWarnings()
self.assertIs(warnings[0]['category'], DeprecationWarning)
self.assertEqual(warnings[0]['message'], message)
self.assertEqual(len(warnings), 1)
def test_colors(self):
"""
The L{insults.colors} module is deprecated
"""
namedAny('twisted.conch.insults.colors')
self.ensureDeprecated("twisted.conch.insults.colors was deprecated "
"in Twisted 10.1.0: Please use "
"twisted.conch.insults.helper instead.")
def test_client(self):
"""
The L{insults.client} module is deprecated
"""
namedAny('twisted.conch.insults.client')
self.ensureDeprecated("twisted.conch.insults.client was deprecated "
"in Twisted 10.1.0: Please use "
"twisted.conch.insults.insults instead.")

Some files were not shown because too many files have changed in this diff Show More