mirror of
https://github.com/JamesonHuang/OpenWrt_Luci_Lua.git
synced 2025-02-21 02:20:18 +00:00
some source code about twisted
This commit is contained in:
parent
8791e644e5
commit
dff86defdb
9
1_7.http_proxy_server/python/Twisted-15.2.1-source/.gitignore
vendored
Normal file
9
1_7.http_proxy_server/python/Twisted-15.2.1-source/.gitignore
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
*.egg-info/
|
||||
*.o
|
||||
*.py[co]
|
||||
*.so
|
||||
_trial_temp*/
|
||||
build/
|
||||
dropin.cache
|
||||
doc/
|
||||
docs/_build/
|
@ -0,0 +1,21 @@
|
||||
Contributing to Twisted
|
||||
=======================
|
||||
|
||||
As an open source project, Twisted welcomes contributions of many forms.
|
||||
|
||||
Examples of contributions include:
|
||||
|
||||
* Code patches
|
||||
* Documentation improvements
|
||||
* Bug reports and patch reviews
|
||||
|
||||
Extensive contribution guidelines are available online at:
|
||||
|
||||
https://twistedmatrix.com/trac/wiki/ContributingToTwistedLabs
|
||||
|
||||
**Warning: pull requests are ignored!** File a ticket at:
|
||||
|
||||
https://twistedmatrix.com/trac/newticket
|
||||
|
||||
Twisted uses Trac to keep track of bugs, feature requests, and associated
|
||||
patches because GitHub doesn't provide adequate tooling for its community.
|
32
1_7.http_proxy_server/python/Twisted-15.2.1-source/INSTALL
Normal file
32
1_7.http_proxy_server/python/Twisted-15.2.1-source/INSTALL
Normal file
@ -0,0 +1,32 @@
|
||||
Requirements
|
||||
|
||||
Python 2.6 or 2.7.
|
||||
|
||||
Zope Interface 3.6.0 or better (http://pypi.python.org/pypi/zope.interface)
|
||||
|
||||
pyOpenSSL (<http://launchpad.net/pyopenssl>) is required for any SSL APIs.
|
||||
Version 0.10 or newer is required.
|
||||
|
||||
On Windows pywin32 (<http://sourceforge.net/projects/pywin32/files/>) is
|
||||
required. Build 215 or later is highly recommended for reliable operation
|
||||
(this is already included in ActivePython).
|
||||
|
||||
If you would like to use Trial's subunit reporter, then you will need to
|
||||
install Subunit 0.0.2 or later (https://launchpad.net/subunit).
|
||||
|
||||
Installation
|
||||
|
||||
* Debian and Ubuntu
|
||||
Packages are included in the main distribution.
|
||||
|
||||
* FreeBSD, Gentoo
|
||||
Twisted is in their package repositories.
|
||||
|
||||
* Win32
|
||||
Installers are available from http://twistedmatrix.com/
|
||||
|
||||
* Other
|
||||
As with other Python packages, the standard way of installing from source
|
||||
is:
|
||||
|
||||
python setup.py install
|
67
1_7.http_proxy_server/python/Twisted-15.2.1-source/LICENSE
Normal file
67
1_7.http_proxy_server/python/Twisted-15.2.1-source/LICENSE
Normal file
@ -0,0 +1,67 @@
|
||||
Copyright (c) 2001-2015
|
||||
Allen Short
|
||||
Andy Gayton
|
||||
Andrew Bennetts
|
||||
Antoine Pitrou
|
||||
Apple Computer, Inc.
|
||||
Ashwini Oruganti
|
||||
Benjamin Bruheim
|
||||
Bob Ippolito
|
||||
Canonical Limited
|
||||
Christopher Armstrong
|
||||
David Reid
|
||||
Donovan Preston
|
||||
Eric Mangold
|
||||
Eyal Lotem
|
||||
Google Inc.
|
||||
Hawkie Owl
|
||||
Hybrid Logic Ltd.
|
||||
Hynek Schlawack
|
||||
Itamar Turner-Trauring
|
||||
James Knight
|
||||
Jason A. Mobarak
|
||||
Jean-Paul Calderone
|
||||
Jessica McKellar
|
||||
Jonathan Jacobs
|
||||
Jonathan Lange
|
||||
Jonathan D. Simms
|
||||
Jürgen Hermann
|
||||
Julian Berman
|
||||
Kevin Horn
|
||||
Kevin Turner
|
||||
Laurens Van Houtven
|
||||
Mary Gardiner
|
||||
Matthew Lefkowitz
|
||||
Massachusetts Institute of Technology
|
||||
Moshe Zadka
|
||||
Paul Swartz
|
||||
Pavel Pergamenshchik
|
||||
Ralph Meijer
|
||||
Richard Wall
|
||||
Sean Riley
|
||||
Software Freedom Conservancy
|
||||
Tavendo GmbH
|
||||
Travis B. Hartwell
|
||||
Thijs Triemstra
|
||||
Thomas Herve
|
||||
Timothy Allen
|
||||
Tom Prince
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of this software and associated documentation files (the
|
||||
"Software"), to deal in the Software without restriction, including
|
||||
without limitation the rights to use, copy, modify, merge, publish,
|
||||
distribute, sublicense, and/or sell copies of the Software, and to
|
||||
permit persons to whom the Software is furnished to do so, subject to
|
||||
the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
5102
1_7.http_proxy_server/python/Twisted-15.2.1-source/NEWS
Normal file
5102
1_7.http_proxy_server/python/Twisted-15.2.1-source/NEWS
Normal file
File diff suppressed because it is too large
Load Diff
114
1_7.http_proxy_server/python/Twisted-15.2.1-source/README
Normal file
114
1_7.http_proxy_server/python/Twisted-15.2.1-source/README
Normal file
@ -0,0 +1,114 @@
|
||||
Twisted 15.2.1
|
||||
|
||||
Quote of the Release:
|
||||
|
||||
<hynek> is there a race condition in threading tests? how could that happen :>
|
||||
|
||||
|
||||
For information on what's new in Twisted 15.2.1, see the NEWS file that comes
|
||||
with the distribution.
|
||||
|
||||
What is this?
|
||||
=============
|
||||
|
||||
Twisted is an event-based framework for internet applications. It includes
|
||||
modules for many different purposes, including the following:
|
||||
|
||||
- twisted.application
|
||||
A "Service" system that allows you to organize your application in
|
||||
hierarchies with well-defined startup and dependency semantics,
|
||||
- twisted.cred
|
||||
A general credentials and authentication system that facilitates
|
||||
pluggable authentication backends,
|
||||
- twisted.enterprise
|
||||
Asynchronous database access, compatible with any Python DBAPI2.0
|
||||
modules,
|
||||
- twisted.internet
|
||||
Low-level asynchronous networking APIs that allow you to define
|
||||
your own protocols that run over certain transports,
|
||||
- twisted.manhole
|
||||
A tool for remote debugging of your services which gives you a
|
||||
Python interactive interpreter,
|
||||
- twisted.protocols
|
||||
Basic protocol implementations and helpers for your own protocol
|
||||
implementations,
|
||||
- twisted.python
|
||||
A large set of utilities for Python tricks, reflection, text
|
||||
processing, and anything else,
|
||||
- twisted.spread
|
||||
A secure, fast remote object system,
|
||||
- twisted.trial
|
||||
A unit testing framework that integrates well with Twisted-based code.
|
||||
|
||||
Twisted supports integration of the Win32, Tk, GTK+ and GTK+ 2 event loops
|
||||
with its main event loop. There is experimental support for Mac OS X and
|
||||
wxPython event loop integration, which you use at your peril.
|
||||
|
||||
For more information, visit http://www.twistedmatrix.com, or join the list
|
||||
at http://twistedmatrix.com/cgi-bin/mailman/listinfo/twisted-python
|
||||
|
||||
There are many official Twisted subprojects, including clients and
|
||||
servers for web, mail, DNS, and more. You can find out more about
|
||||
these projects at http://twistedmatrix.com/trac/wiki/TwistedProjects
|
||||
|
||||
|
||||
Installing
|
||||
==========
|
||||
|
||||
Instructions for installing this software are in INSTALL.
|
||||
|
||||
Unit Tests
|
||||
==========
|
||||
|
||||
See our unit tests run proving that the software is BugFree(TM):
|
||||
|
||||
% trial twisted
|
||||
|
||||
Some of these tests may fail if you
|
||||
* don't have the dependancies required for a particular subsystem installed,
|
||||
* have a firewall blocking some ports (or things like Multicast, which Linux
|
||||
NAT has shown itself to do), or
|
||||
* run them as root.
|
||||
|
||||
|
||||
Documentation and Support
|
||||
=========================
|
||||
|
||||
Twisted's documentation is available from the Twisted Matrix website:
|
||||
|
||||
http://twistedmatrix.com/documents/current/
|
||||
|
||||
This documentation contains how-tos, code examples, and an API reference.
|
||||
|
||||
Help is also available on the Twisted mailing list:
|
||||
|
||||
http://twistedmatrix.com/cgi-bin/mailman/listinfo/twisted-python
|
||||
|
||||
There is also a pair of very lively IRC channels, #twisted (for general
|
||||
Twisted questions) and #twisted.web (for Twisted Web), on chat.freenode.net.
|
||||
|
||||
|
||||
Copyright
|
||||
=========
|
||||
|
||||
All of the code in this distribution is Copyright (c) 2001-2015
|
||||
Twisted Matrix Laboratories.
|
||||
|
||||
Twisted is made available under the MIT license. The included
|
||||
LICENSE file describes this in detail.
|
||||
|
||||
|
||||
Warranty
|
||||
========
|
||||
|
||||
THIS SOFTWARE IS PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER
|
||||
EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
|
||||
OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
|
||||
TO THE USE OF THIS SOFTWARE IS WITH YOU.
|
||||
|
||||
IN NO EVENT WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY
|
||||
AND/OR REDISTRIBUTE THE LIBRARY, BE LIABLE TO YOU FOR ANY DAMAGES, EVEN IF
|
||||
SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH
|
||||
DAMAGES.
|
||||
|
||||
Again, see the included LICENSE file for specific legal details.
|
@ -0,0 +1,19 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# This makes sure that users don't have to set up their environment
|
||||
# specially in order to run these programs from bin/.
|
||||
|
||||
# This helper is shared by many different actual scripts. It is not intended to
|
||||
# be packaged or installed, it is only a developer convenience. By the time
|
||||
# Twisted is actually installed somewhere, the environment should already be set
|
||||
# up properly without the help of this tool.
|
||||
|
||||
import sys, os
|
||||
|
||||
path = os.path.abspath(sys.argv[0])
|
||||
while os.path.dirname(path) != path:
|
||||
if os.path.exists(os.path.join(path, 'twisted', '__init__.py')):
|
||||
sys.path.insert(0, path)
|
||||
break
|
||||
path = os.path.dirname(path)
|
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/cftp
Executable file
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/cftp
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.conch.scripts.cftp import run
|
||||
run()
|
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/ckeygen
Executable file
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/ckeygen
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.conch.scripts.ckeygen import run
|
||||
run()
|
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/conch
Executable file
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/conch
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.conch.scripts.conch import run
|
||||
run()
|
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/tkconch
Executable file
15
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/conch/tkconch
Executable file
@ -0,0 +1,15 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.conch.scripts.tkconch import run
|
||||
run()
|
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/lore/lore
Executable file
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/lore/lore
Executable file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.lore.scripts.lore import run
|
||||
run()
|
||||
|
20
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/mail/mailmail
Executable file
20
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/mail/mailmail
Executable file
@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This script attempts to send some email.
|
||||
"""
|
||||
|
||||
import sys, os
|
||||
extra = os.path.dirname(os.path.dirname(sys.argv[0]))
|
||||
sys.path.insert(0, extra)
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
sys.path.remove(extra)
|
||||
|
||||
from twisted.mail.scripts import mailmail
|
||||
mailmail.run()
|
||||
|
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/manhole
Executable file
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/manhole
Executable file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This script runs GtkManhole, a client for Twisted.Manhole
|
||||
"""
|
||||
import sys
|
||||
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
from twisted.scripts import manhole
|
||||
manhole.run()
|
12
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/pyhtmlizer
Executable file
12
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/pyhtmlizer
Executable file
@ -0,0 +1,12 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
import sys
|
||||
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
from twisted.scripts.htmlizer import run
|
||||
run()
|
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/tap2deb
Executable file
16
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/tap2deb
Executable file
@ -0,0 +1,16 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
tap2deb
|
||||
"""
|
||||
import sys
|
||||
|
||||
try:
|
||||
__import__('_preamble')
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
from twisted.scripts import tap2deb
|
||||
tap2deb.run()
|
19
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/tap2rpm
Executable file
19
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/tap2rpm
Executable file
@ -0,0 +1,19 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# based off the tap2deb code
|
||||
# tap2rpm built by Sean Reifschneider, <jafo@tummy.com>
|
||||
|
||||
"""
|
||||
tap2rpm
|
||||
"""
|
||||
import sys
|
||||
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
from twisted.scripts import tap2rpm
|
||||
tap2rpm.run()
|
18
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/trial
Executable file
18
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/trial
Executable file
@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
import os, sys
|
||||
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
# begin chdir armor
|
||||
sys.path[:] = map(os.path.abspath, sys.path)
|
||||
# end chdir armor
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.getcwd()))
|
||||
|
||||
from twisted.scripts.trial import run
|
||||
run()
|
14
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/twistd
Executable file
14
1_7.http_proxy_server/python/Twisted-15.2.1-source/bin/twistd
Executable file
@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
import os, sys
|
||||
|
||||
try:
|
||||
import _preamble
|
||||
except ImportError:
|
||||
sys.exc_clear()
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.getcwd()))
|
||||
|
||||
from twisted.scripts.twistd import run
|
||||
run()
|
76
1_7.http_proxy_server/python/Twisted-15.2.1-source/setup.py
Executable file
76
1_7.http_proxy_server/python/Twisted-15.2.1-source/setup.py
Executable file
@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Distutils installer for Twisted.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Load setuptools, to build a specific source package
|
||||
import setuptools
|
||||
# Tell Twisted not to enforce zope.interface requirement on import, since
|
||||
# we're going to have to import twisted.python.dist and can rely on
|
||||
# setuptools to install dependencies.
|
||||
setuptools._TWISTED_NO_CHECK_REQUIREMENTS = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
Invoke twisted.python.dist with the appropriate metadata about the
|
||||
Twisted package.
|
||||
"""
|
||||
# On Python 3, use setup3.py until Python 3 port is done:
|
||||
if sys.version_info[0] > 2:
|
||||
import setup3
|
||||
setup3.main()
|
||||
return
|
||||
|
||||
if os.path.exists('twisted'):
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
setup_args = {}
|
||||
|
||||
if 'setuptools' in sys.modules:
|
||||
from pkg_resources import parse_requirements
|
||||
requirements = ["zope.interface >= 3.6.0"]
|
||||
try:
|
||||
list(parse_requirements(requirements))
|
||||
except:
|
||||
print("""You seem to be running a very old version of setuptools.
|
||||
This version of setuptools has a bug parsing dependencies, so automatic
|
||||
dependency resolution is disabled.
|
||||
""")
|
||||
else:
|
||||
setup_args['install_requires'] = requirements
|
||||
setup_args['include_package_data'] = True
|
||||
setup_args['zip_safe'] = False
|
||||
|
||||
from twisted.python.dist import (
|
||||
STATIC_PACKAGE_METADATA, getDataFiles, getExtensions, getAllScripts,
|
||||
getPackages, setup, _EXTRAS_REQUIRE)
|
||||
|
||||
scripts = getAllScripts()
|
||||
|
||||
setup_args.update(dict(
|
||||
packages=getPackages('twisted'),
|
||||
conditionalExtensions=getExtensions(),
|
||||
scripts=scripts,
|
||||
extras_require=_EXTRAS_REQUIRE,
|
||||
data_files=getDataFiles('twisted'),
|
||||
**STATIC_PACKAGE_METADATA))
|
||||
|
||||
setup(**setup_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main(sys.argv[1:])
|
||||
except KeyboardInterrupt:
|
||||
sys.exit(1)
|
51
1_7.http_proxy_server/python/Twisted-15.2.1-source/setup3.py
Normal file
51
1_7.http_proxy_server/python/Twisted-15.2.1-source/setup3.py
Normal file
@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python3.3
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# This is a temporary helper to be able to build and install distributions of
|
||||
# Twisted on/for Python 3. Once all of Twisted has been ported, it should go
|
||||
# away and setup.py should work for either Python 2 or Python 3.
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
import sys
|
||||
import os
|
||||
from distutils.command.sdist import sdist
|
||||
|
||||
|
||||
class DisabledSdist(sdist):
|
||||
"""
|
||||
A version of the sdist command that does nothing.
|
||||
"""
|
||||
def run(self):
|
||||
sys.stderr.write(
|
||||
"The sdist command only works with Python 2 at the moment.\n")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
from setuptools import setup
|
||||
except ImportError:
|
||||
from distutils.core import setup
|
||||
|
||||
# Make sure the to-be-installed version of Twisted is used, if available,
|
||||
# since we're importing from it:
|
||||
if os.path.exists('twisted'):
|
||||
sys.path.insert(0, '.')
|
||||
|
||||
from twisted.python.dist3 import modulesToInstall
|
||||
from twisted.python.dist import STATIC_PACKAGE_METADATA
|
||||
|
||||
args = STATIC_PACKAGE_METADATA.copy()
|
||||
args['install_requires'] = ["zope.interface >= 4.0.2"]
|
||||
args['py_modules'] = modulesToInstall
|
||||
args['cmdclass'] = {'sdist': DisabledSdist}
|
||||
|
||||
setup(**args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,71 @@
|
||||
# -*- test-case-name: twisted -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Twisted: The Framework Of Your Internet.
|
||||
"""
|
||||
|
||||
def _checkRequirements():
|
||||
# Don't allow the user to run a version of Python we don't support.
|
||||
import sys
|
||||
|
||||
version = getattr(sys, "version_info", (0,))
|
||||
if version < (2, 6):
|
||||
raise ImportError("Twisted requires Python 2.6 or later.")
|
||||
if version < (3, 0):
|
||||
required = "3.6.0"
|
||||
else:
|
||||
required = "4.0.0"
|
||||
|
||||
if ("setuptools" in sys.modules and
|
||||
getattr(sys.modules["setuptools"],
|
||||
"_TWISTED_NO_CHECK_REQUIREMENTS", None) is not None):
|
||||
# Skip requirement checks, setuptools ought to take care of installing
|
||||
# the dependencies.
|
||||
return
|
||||
|
||||
# Don't allow the user to run with a version of zope.interface we don't
|
||||
# support.
|
||||
required = "Twisted requires zope.interface %s or later" % (required,)
|
||||
try:
|
||||
from zope import interface
|
||||
except ImportError:
|
||||
# It isn't installed.
|
||||
raise ImportError(required + ": no module named zope.interface.")
|
||||
except:
|
||||
# It is installed but not compatible with this version of Python.
|
||||
raise ImportError(required + ".")
|
||||
try:
|
||||
# Try using the API that we need, which only works right with
|
||||
# zope.interface 3.6 (or 4.0 on Python 3)
|
||||
class IDummy(interface.Interface):
|
||||
pass
|
||||
@interface.implementer(IDummy)
|
||||
class Dummy(object):
|
||||
pass
|
||||
except TypeError:
|
||||
# It is installed but not compatible with this version of Python.
|
||||
raise ImportError(required + ".")
|
||||
|
||||
_checkRequirements()
|
||||
|
||||
# Ensure compat gets imported
|
||||
from twisted.python import compat
|
||||
|
||||
# setup version
|
||||
from twisted._version import version
|
||||
__version__ = version.short()
|
||||
|
||||
del compat
|
||||
|
||||
# Deprecating lore.
|
||||
from twisted.python.versions import Version
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 14, 0, 0),
|
||||
"Use Sphinx instead.",
|
||||
"twisted", "lore")
|
@ -0,0 +1,11 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# This is an auto-generated file. Do not edit it.
|
||||
|
||||
"""
|
||||
Provides Twisted version information.
|
||||
"""
|
||||
|
||||
from twisted.python import versions
|
||||
version = versions.Version('twisted', 15, 2, 1)
|
@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Configuration objects for Twisted Applications.
|
||||
"""
|
@ -0,0 +1,678 @@
|
||||
# -*- test-case-name: twisted.test.test_application,twisted.test.test_twistd -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
import pdb
|
||||
import getpass
|
||||
import traceback
|
||||
import signal
|
||||
from operator import attrgetter
|
||||
|
||||
from twisted.python import runtime, log, usage, failure, util, logfile
|
||||
from twisted.python.reflect import qual, namedAny
|
||||
from twisted.python.log import ILogObserver
|
||||
from twisted.persisted import sob
|
||||
from twisted.application import service, reactors
|
||||
from twisted.internet import defer
|
||||
from twisted import copyright, plugin
|
||||
|
||||
# Expose the new implementation of installReactor at the old location.
|
||||
from twisted.application.reactors import installReactor
|
||||
from twisted.application.reactors import NoSuchReactor
|
||||
|
||||
|
||||
class _BasicProfiler(object):
|
||||
"""
|
||||
@ivar saveStats: if C{True}, save the stats information instead of the
|
||||
human readable format
|
||||
@type saveStats: C{bool}
|
||||
|
||||
@ivar profileOutput: the name of the file use to print profile data.
|
||||
@type profileOutput: C{str}
|
||||
"""
|
||||
|
||||
def __init__(self, profileOutput, saveStats):
|
||||
self.profileOutput = profileOutput
|
||||
self.saveStats = saveStats
|
||||
|
||||
|
||||
def _reportImportError(self, module, e):
|
||||
"""
|
||||
Helper method to report an import error with a profile module. This
|
||||
has to be explicit because some of these modules are removed by
|
||||
distributions due to them being non-free.
|
||||
"""
|
||||
s = "Failed to import module %s: %s" % (module, e)
|
||||
s += """
|
||||
This is most likely caused by your operating system not including
|
||||
the module due to it being non-free. Either do not use the option
|
||||
--profile, or install the module; your operating system vendor
|
||||
may provide it in a separate package.
|
||||
"""
|
||||
raise SystemExit(s)
|
||||
|
||||
|
||||
|
||||
class ProfileRunner(_BasicProfiler):
|
||||
"""
|
||||
Runner for the standard profile module.
|
||||
"""
|
||||
|
||||
def run(self, reactor):
|
||||
"""
|
||||
Run reactor under the standard profiler.
|
||||
"""
|
||||
try:
|
||||
import profile
|
||||
except ImportError as e:
|
||||
self._reportImportError("profile", e)
|
||||
|
||||
p = profile.Profile()
|
||||
p.runcall(reactor.run)
|
||||
if self.saveStats:
|
||||
p.dump_stats(self.profileOutput)
|
||||
else:
|
||||
tmp, sys.stdout = sys.stdout, open(self.profileOutput, 'a')
|
||||
try:
|
||||
p.print_stats()
|
||||
finally:
|
||||
sys.stdout, tmp = tmp, sys.stdout
|
||||
tmp.close()
|
||||
|
||||
|
||||
|
||||
class HotshotRunner(_BasicProfiler):
|
||||
"""
|
||||
Runner for the hotshot profile module.
|
||||
"""
|
||||
|
||||
def run(self, reactor):
|
||||
"""
|
||||
Run reactor under the hotshot profiler.
|
||||
"""
|
||||
try:
|
||||
import hotshot.stats
|
||||
except (ImportError, SystemExit) as e:
|
||||
# Certain versions of Debian (and Debian derivatives) raise
|
||||
# SystemExit when importing hotshot if the "non-free" profiler
|
||||
# module is not installed. Someone eventually recognized this
|
||||
# as a bug and changed the Debian packaged Python to raise
|
||||
# ImportError instead. Handle both exception types here in
|
||||
# order to support the versions of Debian which have this
|
||||
# behavior. The bug report which prompted the introduction of
|
||||
# this highly undesirable behavior should be available online at
|
||||
# <http://bugs.debian.org/cgi-bin/bugreport.cgi?bug=334067>.
|
||||
# There seems to be no corresponding bug report which resulted
|
||||
# in the behavior being removed. -exarkun
|
||||
self._reportImportError("hotshot", e)
|
||||
|
||||
# this writes stats straight out
|
||||
p = hotshot.Profile(self.profileOutput)
|
||||
p.runcall(reactor.run)
|
||||
if self.saveStats:
|
||||
# stats are automatically written to file, nothing to do
|
||||
return
|
||||
else:
|
||||
s = hotshot.stats.load(self.profileOutput)
|
||||
s.strip_dirs()
|
||||
s.sort_stats(-1)
|
||||
s.stream = open(self.profileOutput, 'w')
|
||||
s.print_stats()
|
||||
s.stream.close()
|
||||
|
||||
|
||||
|
||||
class CProfileRunner(_BasicProfiler):
|
||||
"""
|
||||
Runner for the cProfile module.
|
||||
"""
|
||||
|
||||
def run(self, reactor):
|
||||
"""
|
||||
Run reactor under the cProfile profiler.
|
||||
"""
|
||||
try:
|
||||
import cProfile
|
||||
import pstats
|
||||
except ImportError as e:
|
||||
self._reportImportError("cProfile", e)
|
||||
|
||||
p = cProfile.Profile()
|
||||
p.runcall(reactor.run)
|
||||
if self.saveStats:
|
||||
p.dump_stats(self.profileOutput)
|
||||
else:
|
||||
stream = open(self.profileOutput, 'w')
|
||||
s = pstats.Stats(p, stream=stream)
|
||||
s.strip_dirs()
|
||||
s.sort_stats(-1)
|
||||
s.print_stats()
|
||||
stream.close()
|
||||
|
||||
|
||||
|
||||
class AppProfiler(object):
|
||||
"""
|
||||
Class which selects a specific profile runner based on configuration
|
||||
options.
|
||||
|
||||
@ivar profiler: the name of the selected profiler.
|
||||
@type profiler: C{str}
|
||||
"""
|
||||
profilers = {"profile": ProfileRunner, "hotshot": HotshotRunner,
|
||||
"cprofile": CProfileRunner}
|
||||
|
||||
def __init__(self, options):
|
||||
saveStats = options.get("savestats", False)
|
||||
profileOutput = options.get("profile", None)
|
||||
self.profiler = options.get("profiler", "hotshot").lower()
|
||||
if self.profiler in self.profilers:
|
||||
profiler = self.profilers[self.profiler](profileOutput, saveStats)
|
||||
self.run = profiler.run
|
||||
else:
|
||||
raise SystemExit("Unsupported profiler name: %s" %
|
||||
(self.profiler,))
|
||||
|
||||
|
||||
|
||||
class AppLogger(object):
|
||||
"""
|
||||
An L{AppLogger} attaches the configured log observer specified on the
|
||||
commandline to a L{ServerOptions} object or the custom L{ILogObserver}.
|
||||
|
||||
@ivar _logfilename: The name of the file to which to log, if other than the
|
||||
default.
|
||||
@type _logfilename: C{str}
|
||||
|
||||
@ivar _observerFactory: Callable object that will create a log observer, or
|
||||
None.
|
||||
|
||||
@ivar _observer: log observer added at C{start} and removed at C{stop}.
|
||||
@type _observer: C{callable}
|
||||
"""
|
||||
_observer = None
|
||||
|
||||
def __init__(self, options):
|
||||
"""
|
||||
Initialize an L{AppLogger} with a L{ServerOptions}.
|
||||
"""
|
||||
self._logfilename = options.get("logfile", "")
|
||||
self._observerFactory = options.get("logger") or None
|
||||
|
||||
|
||||
def start(self, application):
|
||||
"""
|
||||
Initialize the global logging system for the given application.
|
||||
|
||||
If a custom logger was specified on the command line it will be used.
|
||||
If not, and an L{ILogObserver} component has been set on
|
||||
C{application}, then it will be used as the log observer. Otherwise a
|
||||
log observer will be created based on the command-line options for
|
||||
built-in loggers (e.g. C{--logfile}).
|
||||
|
||||
@param application: The application on which to check for an
|
||||
L{ILogObserver}.
|
||||
@type application: L{twisted.python.components.Componentized}
|
||||
"""
|
||||
if self._observerFactory is not None:
|
||||
observer = self._observerFactory()
|
||||
else:
|
||||
observer = application.getComponent(ILogObserver, None)
|
||||
|
||||
if observer is None:
|
||||
observer = self._getLogObserver()
|
||||
self._observer = observer
|
||||
log.startLoggingWithObserver(self._observer)
|
||||
self._initialLog()
|
||||
|
||||
|
||||
def _initialLog(self):
|
||||
"""
|
||||
Print twistd start log message.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
log.msg("twistd %s (%s %s) starting up." % (
|
||||
copyright.version, sys.executable, runtime.shortPythonVersion())
|
||||
)
|
||||
log.msg('reactor class: %s.' % (qual(reactor.__class__),))
|
||||
|
||||
|
||||
def _getLogObserver(self):
|
||||
"""
|
||||
Create a log observer to be added to the logging system before running
|
||||
this application.
|
||||
"""
|
||||
if self._logfilename == '-' or not self._logfilename:
|
||||
logFile = sys.stdout
|
||||
else:
|
||||
logFile = logfile.LogFile.fromFullPath(self._logfilename)
|
||||
return log.FileLogObserver(logFile).emit
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""
|
||||
Remove all log observers previously set up by L{AppLogger.start}.
|
||||
"""
|
||||
log.msg("Server Shut Down.")
|
||||
if self._observer is not None:
|
||||
log.removeObserver(self._observer)
|
||||
self._observer = None
|
||||
|
||||
|
||||
|
||||
def fixPdb():
|
||||
def do_stop(self, arg):
|
||||
self.clear_all_breaks()
|
||||
self.set_continue()
|
||||
from twisted.internet import reactor
|
||||
reactor.callLater(0, reactor.stop)
|
||||
return 1
|
||||
|
||||
|
||||
def help_stop(self):
|
||||
print("stop - Continue execution, then cleanly shutdown the twisted "
|
||||
"reactor.")
|
||||
|
||||
|
||||
def set_quit(self):
|
||||
os._exit(0)
|
||||
|
||||
pdb.Pdb.set_quit = set_quit
|
||||
pdb.Pdb.do_stop = do_stop
|
||||
pdb.Pdb.help_stop = help_stop
|
||||
|
||||
|
||||
|
||||
def runReactorWithLogging(config, oldstdout, oldstderr, profiler=None,
|
||||
reactor=None):
|
||||
"""
|
||||
Start the reactor, using profiling if specified by the configuration, and
|
||||
log any error happening in the process.
|
||||
|
||||
@param config: configuration of the twistd application.
|
||||
@type config: L{ServerOptions}
|
||||
|
||||
@param oldstdout: initial value of C{sys.stdout}.
|
||||
@type oldstdout: C{file}
|
||||
|
||||
@param oldstderr: initial value of C{sys.stderr}.
|
||||
@type oldstderr: C{file}
|
||||
|
||||
@param profiler: object used to run the reactor with profiling.
|
||||
@type profiler: L{AppProfiler}
|
||||
|
||||
@param reactor: The reactor to use. If C{None}, the global reactor will
|
||||
be used.
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
try:
|
||||
if config['profile']:
|
||||
if profiler is not None:
|
||||
profiler.run(reactor)
|
||||
elif config['debug']:
|
||||
sys.stdout = oldstdout
|
||||
sys.stderr = oldstderr
|
||||
if runtime.platformType == 'posix':
|
||||
signal.signal(signal.SIGUSR2, lambda *args: pdb.set_trace())
|
||||
signal.signal(signal.SIGINT, lambda *args: pdb.set_trace())
|
||||
fixPdb()
|
||||
pdb.runcall(reactor.run)
|
||||
else:
|
||||
reactor.run()
|
||||
except:
|
||||
if config['nodaemon']:
|
||||
file = oldstdout
|
||||
else:
|
||||
file = open("TWISTD-CRASH.log", "a")
|
||||
traceback.print_exc(file=file)
|
||||
file.flush()
|
||||
|
||||
|
||||
|
||||
def getPassphrase(needed):
|
||||
if needed:
|
||||
return getpass.getpass('Passphrase: ')
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def getSavePassphrase(needed):
|
||||
if needed:
|
||||
return util.getPassword("Encryption passphrase: ")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class ApplicationRunner(object):
|
||||
"""
|
||||
An object which helps running an application based on a config object.
|
||||
|
||||
Subclass me and implement preApplication and postApplication
|
||||
methods. postApplication generally will want to run the reactor
|
||||
after starting the application.
|
||||
|
||||
@ivar config: The config object, which provides a dict-like interface.
|
||||
|
||||
@ivar application: Available in postApplication, but not
|
||||
preApplication. This is the application object.
|
||||
|
||||
@ivar profilerFactory: Factory for creating a profiler object, able to
|
||||
profile the application if options are set accordingly.
|
||||
|
||||
@ivar profiler: Instance provided by C{profilerFactory}.
|
||||
|
||||
@ivar loggerFactory: Factory for creating object responsible for logging.
|
||||
|
||||
@ivar logger: Instance provided by C{loggerFactory}.
|
||||
"""
|
||||
profilerFactory = AppProfiler
|
||||
loggerFactory = AppLogger
|
||||
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
self.profiler = self.profilerFactory(config)
|
||||
self.logger = self.loggerFactory(config)
|
||||
|
||||
|
||||
def run(self):
|
||||
"""
|
||||
Run the application.
|
||||
"""
|
||||
self.preApplication()
|
||||
self.application = self.createOrGetApplication()
|
||||
|
||||
self.logger.start(self.application)
|
||||
|
||||
self.postApplication()
|
||||
self.logger.stop()
|
||||
|
||||
|
||||
def startReactor(self, reactor, oldstdout, oldstderr):
|
||||
"""
|
||||
Run the reactor with the given configuration. Subclasses should
|
||||
probably call this from C{postApplication}.
|
||||
|
||||
@see: L{runReactorWithLogging}
|
||||
"""
|
||||
runReactorWithLogging(
|
||||
self.config, oldstdout, oldstderr, self.profiler, reactor)
|
||||
|
||||
|
||||
def preApplication(self):
|
||||
"""
|
||||
Override in subclass.
|
||||
|
||||
This should set up any state necessary before loading and
|
||||
running the Application.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def postApplication(self):
|
||||
"""
|
||||
Override in subclass.
|
||||
|
||||
This will be called after the application has been loaded (so
|
||||
the C{application} attribute will be set). Generally this
|
||||
should start the application and run the reactor.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def createOrGetApplication(self):
|
||||
"""
|
||||
Create or load an Application based on the parameters found in the
|
||||
given L{ServerOptions} instance.
|
||||
|
||||
If a subcommand was used, the L{service.IServiceMaker} that it
|
||||
represents will be used to construct a service to be added to
|
||||
a newly-created Application.
|
||||
|
||||
Otherwise, an application will be loaded based on parameters in
|
||||
the config.
|
||||
"""
|
||||
if self.config.subCommand:
|
||||
# If a subcommand was given, it's our responsibility to create
|
||||
# the application, instead of load it from a file.
|
||||
|
||||
# loadedPlugins is set up by the ServerOptions.subCommands
|
||||
# property, which is iterated somewhere in the bowels of
|
||||
# usage.Options.
|
||||
plg = self.config.loadedPlugins[self.config.subCommand]
|
||||
ser = plg.makeService(self.config.subOptions)
|
||||
application = service.Application(plg.tapname)
|
||||
ser.setServiceParent(application)
|
||||
else:
|
||||
passphrase = getPassphrase(self.config['encrypted'])
|
||||
application = getApplication(self.config, passphrase)
|
||||
return application
|
||||
|
||||
|
||||
|
||||
def getApplication(config, passphrase):
|
||||
s = [(config[t], t)
|
||||
for t in ['python', 'source', 'file'] if config[t]][0]
|
||||
filename, style = s[0], {'file': 'pickle'}.get(s[1], s[1])
|
||||
try:
|
||||
log.msg("Loading %s..." % filename)
|
||||
application = service.loadApplication(filename, style, passphrase)
|
||||
log.msg("Loaded.")
|
||||
except Exception as e:
|
||||
s = "Failed to load application: %s" % e
|
||||
if isinstance(e, KeyError) and e.args[0] == "application":
|
||||
s += """
|
||||
Could not find 'application' in the file. To use 'twistd -y', your .tac
|
||||
file must create a suitable object (e.g., by calling service.Application())
|
||||
and store it in a variable named 'application'. twistd loads your .tac file
|
||||
and scans the global variables for one of this name.
|
||||
|
||||
Please read the 'Using Application' HOWTO for details.
|
||||
"""
|
||||
traceback.print_exc(file=log.logfile)
|
||||
log.msg(s)
|
||||
log.deferr()
|
||||
sys.exit('\n' + s + '\n')
|
||||
return application
|
||||
|
||||
|
||||
|
||||
def _reactorAction():
|
||||
return usage.CompleteList([r.shortName for r in
|
||||
reactors.getReactorTypes()])
|
||||
|
||||
|
||||
|
||||
class ReactorSelectionMixin:
|
||||
"""
|
||||
Provides options for selecting a reactor to install.
|
||||
|
||||
If a reactor is installed, the short name which was used to locate it is
|
||||
saved as the value for the C{"reactor"} key.
|
||||
"""
|
||||
compData = usage.Completions(
|
||||
optActions={"reactor": _reactorAction})
|
||||
|
||||
messageOutput = sys.stdout
|
||||
_getReactorTypes = staticmethod(reactors.getReactorTypes)
|
||||
|
||||
|
||||
def opt_help_reactors(self):
|
||||
"""
|
||||
Display a list of possibly available reactor names.
|
||||
"""
|
||||
rcts = sorted(self._getReactorTypes(), key=attrgetter('shortName'))
|
||||
for r in rcts:
|
||||
self.messageOutput.write(' %-4s\t%s\n' %
|
||||
(r.shortName, r.description))
|
||||
raise SystemExit(0)
|
||||
|
||||
|
||||
def opt_reactor(self, shortName):
|
||||
"""
|
||||
Which reactor to use (see --help-reactors for a list of possibilities)
|
||||
"""
|
||||
# Actually actually actually install the reactor right at this very
|
||||
# moment, before any other code (for example, a sub-command plugin)
|
||||
# runs and accidentally imports and installs the default reactor.
|
||||
#
|
||||
# This could probably be improved somehow.
|
||||
try:
|
||||
installReactor(shortName)
|
||||
except NoSuchReactor:
|
||||
msg = ("The specified reactor does not exist: '%s'.\n"
|
||||
"See the list of available reactors with "
|
||||
"--help-reactors" % (shortName,))
|
||||
raise usage.UsageError(msg)
|
||||
except Exception as e:
|
||||
msg = ("The specified reactor cannot be used, failed with error: "
|
||||
"%s.\nSee the list of available reactors with "
|
||||
"--help-reactors" % (e,))
|
||||
raise usage.UsageError(msg)
|
||||
else:
|
||||
self["reactor"] = shortName
|
||||
opt_r = opt_reactor
|
||||
|
||||
|
||||
|
||||
class ServerOptions(usage.Options, ReactorSelectionMixin):
|
||||
|
||||
longdesc = ("twistd reads a twisted.application.service.Application out "
|
||||
"of a file and runs it.")
|
||||
|
||||
optFlags = [['savestats', None,
|
||||
"save the Stats object rather than the text output of "
|
||||
"the profiler."],
|
||||
['no_save', 'o', "do not save state on shutdown"],
|
||||
['encrypted', 'e',
|
||||
"The specified tap/aos file is encrypted."]]
|
||||
|
||||
optParameters = [['logfile', 'l', None,
|
||||
"log to a specified file, - for stdout"],
|
||||
['logger', None, None,
|
||||
"A fully-qualified name to a log observer factory to "
|
||||
"use for the initial log observer. Takes precedence "
|
||||
"over --logfile and --syslog (when available)."],
|
||||
['profile', 'p', None,
|
||||
"Run in profile mode, dumping results to specified "
|
||||
"file."],
|
||||
['profiler', None, "hotshot",
|
||||
"Name of the profiler to use (%s)." %
|
||||
", ".join(AppProfiler.profilers)],
|
||||
['file', 'f', 'twistd.tap',
|
||||
"read the given .tap file"],
|
||||
['python', 'y', None,
|
||||
"read an application from within a Python file "
|
||||
"(implies -o)"],
|
||||
['source', 's', None,
|
||||
"Read an application from a .tas file (AOT format)."],
|
||||
['rundir', 'd', '.',
|
||||
'Change to a supplied directory before running']]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("file", "python", "source")],
|
||||
optActions={"file": usage.CompleteFiles("*.tap"),
|
||||
"python": usage.CompleteFiles("*.(tac|py)"),
|
||||
"source": usage.CompleteFiles("*.tas"),
|
||||
"rundir": usage.CompleteDirs()}
|
||||
)
|
||||
|
||||
_getPlugins = staticmethod(plugin.getPlugins)
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
self['debug'] = False
|
||||
usage.Options.__init__(self, *a, **kw)
|
||||
|
||||
|
||||
def opt_debug(self):
|
||||
"""
|
||||
Run the application in the Python Debugger (implies nodaemon),
|
||||
sending SIGUSR2 will drop into debugger
|
||||
"""
|
||||
defer.setDebugging(True)
|
||||
failure.startDebugMode()
|
||||
self['debug'] = True
|
||||
opt_b = opt_debug
|
||||
|
||||
|
||||
def opt_spew(self):
|
||||
"""
|
||||
Print an insanely verbose log of everything that happens.
|
||||
Useful when debugging freezes or locks in complex code."""
|
||||
sys.settrace(util.spewer)
|
||||
try:
|
||||
import threading
|
||||
except ImportError:
|
||||
return
|
||||
threading.settrace(util.spewer)
|
||||
|
||||
|
||||
def parseOptions(self, options=None):
|
||||
if options is None:
|
||||
options = sys.argv[1:] or ["--help"]
|
||||
usage.Options.parseOptions(self, options)
|
||||
|
||||
|
||||
def postOptions(self):
|
||||
if self.subCommand or self['python']:
|
||||
self['no_save'] = True
|
||||
if self['logger'] is not None:
|
||||
try:
|
||||
self['logger'] = namedAny(self['logger'])
|
||||
except Exception as e:
|
||||
raise usage.UsageError("Logger '%s' could not be imported: %s"
|
||||
% (self['logger'], e))
|
||||
|
||||
|
||||
def subCommands(self):
|
||||
plugins = self._getPlugins(service.IServiceMaker)
|
||||
self.loadedPlugins = {}
|
||||
for plug in sorted(plugins, key=attrgetter('tapname')):
|
||||
self.loadedPlugins[plug.tapname] = plug
|
||||
yield (plug.tapname,
|
||||
None,
|
||||
# Avoid resolving the options attribute right away, in case
|
||||
# it's a property with a non-trivial getter (eg, one which
|
||||
# imports modules).
|
||||
lambda plug=plug: plug.options(),
|
||||
plug.description)
|
||||
subCommands = property(subCommands)
|
||||
|
||||
|
||||
|
||||
def run(runApp, ServerOptions):
|
||||
config = ServerOptions()
|
||||
try:
|
||||
config.parseOptions()
|
||||
except usage.error as ue:
|
||||
print(config)
|
||||
print("%s: %s" % (sys.argv[0], ue))
|
||||
else:
|
||||
runApp(config)
|
||||
|
||||
|
||||
|
||||
def convertStyle(filein, typein, passphrase, fileout, typeout, encrypt):
|
||||
application = service.loadApplication(filein, typein, passphrase)
|
||||
sob.IPersistable(application).setStyle(typeout)
|
||||
passphrase = getSavePassphrase(encrypt)
|
||||
if passphrase:
|
||||
fileout = None
|
||||
sob.IPersistable(application).save(filename=fileout, passphrase=passphrase)
|
||||
|
||||
|
||||
|
||||
def startApplication(application, save):
|
||||
from twisted.internet import reactor
|
||||
service.IService(application).startService()
|
||||
if save:
|
||||
p = sob.IPersistable(application)
|
||||
reactor.addSystemEventTrigger('after', 'shutdown', p.save, 'shutdown')
|
||||
reactor.addSystemEventTrigger('before', 'shutdown',
|
||||
service.IService(application).stopService)
|
@ -0,0 +1,396 @@
|
||||
# -*- test-case-name: twisted.application.test.test_internet,twisted.test.test_application,twisted.test.test_cooperator -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Reactor-based Services
|
||||
|
||||
Here are services to run clients, servers and periodic services using
|
||||
the reactor.
|
||||
|
||||
If you want to run a server service, L{StreamServerEndpointService} defines a
|
||||
service that can wrap an arbitrary L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>}
|
||||
as an L{IService}. See also L{twisted.application.strports.service} for
|
||||
constructing one of these directly from a descriptive string.
|
||||
|
||||
Additionally, this module (dynamically) defines various Service subclasses that
|
||||
let you represent clients and servers in a Service hierarchy. Endpoints APIs
|
||||
should be preferred for stream server services, but since those APIs do not yet
|
||||
exist for clients or datagram services, many of these are still useful.
|
||||
|
||||
They are as follows::
|
||||
|
||||
TCPServer, TCPClient,
|
||||
UNIXServer, UNIXClient,
|
||||
SSLServer, SSLClient,
|
||||
UDPServer,
|
||||
UNIXDatagramServer, UNIXDatagramClient,
|
||||
MulticastServer
|
||||
|
||||
These classes take arbitrary arguments in their constructors and pass
|
||||
them straight on to their respective reactor.listenXXX or
|
||||
reactor.connectXXX calls.
|
||||
|
||||
For example, the following service starts a web server on port 8080:
|
||||
C{TCPServer(8080, server.Site(r))}. See the documentation for the
|
||||
reactor.listen/connect* methods for more information.
|
||||
"""
|
||||
|
||||
from twisted.python import log
|
||||
from twisted.application import service
|
||||
from twisted.internet import task
|
||||
|
||||
from twisted.internet.defer import CancelledError
|
||||
|
||||
|
||||
def _maybeGlobalReactor(maybeReactor):
|
||||
"""
|
||||
@return: the argument, or the global reactor if the argument is C{None}.
|
||||
"""
|
||||
if maybeReactor is None:
|
||||
from twisted.internet import reactor
|
||||
return reactor
|
||||
else:
|
||||
return maybeReactor
|
||||
|
||||
|
||||
class _VolatileDataService(service.Service):
|
||||
|
||||
volatile = []
|
||||
|
||||
def __getstate__(self):
|
||||
d = service.Service.__getstate__(self)
|
||||
for attr in self.volatile:
|
||||
if attr in d:
|
||||
del d[attr]
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class _AbstractServer(_VolatileDataService):
|
||||
"""
|
||||
@cvar volatile: list of attribute to remove from pickling.
|
||||
@type volatile: C{list}
|
||||
|
||||
@ivar method: the type of method to call on the reactor, one of B{TCP},
|
||||
B{UDP}, B{SSL} or B{UNIX}.
|
||||
@type method: C{str}
|
||||
|
||||
@ivar reactor: the current running reactor.
|
||||
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
|
||||
C{IReactorSSL} or C{IReactorUnix}.
|
||||
|
||||
@ivar _port: instance of port set when the service is started.
|
||||
@type _port: a provider of L{twisted.internet.interfaces.IListeningPort}.
|
||||
"""
|
||||
|
||||
volatile = ['_port']
|
||||
method = None
|
||||
reactor = None
|
||||
|
||||
_port = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
if 'reactor' in kwargs:
|
||||
self.reactor = kwargs.pop("reactor")
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
def privilegedStartService(self):
|
||||
service.Service.privilegedStartService(self)
|
||||
self._port = self._getPort()
|
||||
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
if self._port is None:
|
||||
self._port = self._getPort()
|
||||
|
||||
|
||||
def stopService(self):
|
||||
service.Service.stopService(self)
|
||||
# TODO: if startup failed, should shutdown skip stopListening?
|
||||
# _port won't exist
|
||||
if self._port is not None:
|
||||
d = self._port.stopListening()
|
||||
del self._port
|
||||
return d
|
||||
|
||||
|
||||
def _getPort(self):
|
||||
"""
|
||||
Wrapper around the appropriate listen method of the reactor.
|
||||
|
||||
@return: the port object returned by the listen method.
|
||||
@rtype: an object providing
|
||||
L{twisted.internet.interfaces.IListeningPort}.
|
||||
"""
|
||||
return getattr(_maybeGlobalReactor(self.reactor),
|
||||
'listen%s' % (self.method,))(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
|
||||
class _AbstractClient(_VolatileDataService):
|
||||
"""
|
||||
@cvar volatile: list of attribute to remove from pickling.
|
||||
@type volatile: C{list}
|
||||
|
||||
@ivar method: the type of method to call on the reactor, one of B{TCP},
|
||||
B{UDP}, B{SSL} or B{UNIX}.
|
||||
@type method: C{str}
|
||||
|
||||
@ivar reactor: the current running reactor.
|
||||
@type reactor: a provider of C{IReactorTCP}, C{IReactorUDP},
|
||||
C{IReactorSSL} or C{IReactorUnix}.
|
||||
|
||||
@ivar _connection: instance of connection set when the service is started.
|
||||
@type _connection: a provider of L{twisted.internet.interfaces.IConnector}.
|
||||
"""
|
||||
volatile = ['_connection']
|
||||
method = None
|
||||
reactor = None
|
||||
|
||||
_connection = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
if 'reactor' in kwargs:
|
||||
self.reactor = kwargs.pop("reactor")
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
self._connection = self._getConnection()
|
||||
|
||||
|
||||
def stopService(self):
|
||||
service.Service.stopService(self)
|
||||
if self._connection is not None:
|
||||
self._connection.disconnect()
|
||||
del self._connection
|
||||
|
||||
|
||||
def _getConnection(self):
|
||||
"""
|
||||
Wrapper around the appropriate connect method of the reactor.
|
||||
|
||||
@return: the port object returned by the connect method.
|
||||
@rtype: an object providing L{twisted.internet.interfaces.IConnector}.
|
||||
"""
|
||||
return getattr(_maybeGlobalReactor(self.reactor),
|
||||
'connect%s' % (self.method,))(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
|
||||
_doc={
|
||||
'Client':
|
||||
"""Connect to %(tran)s
|
||||
|
||||
Call reactor.connect%(tran)s when the service starts, with the
|
||||
arguments given to the constructor.
|
||||
""",
|
||||
'Server':
|
||||
"""Serve %(tran)s clients
|
||||
|
||||
Call reactor.listen%(tran)s when the service starts, with the
|
||||
arguments given to the constructor. When the service stops,
|
||||
stop listening. See twisted.internet.interfaces for documentation
|
||||
on arguments to the reactor method.
|
||||
""",
|
||||
}
|
||||
|
||||
import types
|
||||
for tran in 'TCP UNIX SSL UDP UNIXDatagram Multicast'.split():
|
||||
for side in 'Server Client'.split():
|
||||
if tran == "Multicast" and side == "Client":
|
||||
continue
|
||||
if tran == "UDP" and side == "Client":
|
||||
continue
|
||||
base = globals()['_Abstract'+side]
|
||||
doc = _doc[side] % vars()
|
||||
klass = types.ClassType(tran+side, (base,),
|
||||
{'method': tran, '__doc__': doc})
|
||||
globals()[tran+side] = klass
|
||||
|
||||
|
||||
|
||||
class TimerService(_VolatileDataService):
|
||||
"""
|
||||
Service to periodically call a function
|
||||
|
||||
Every C{step} seconds call the given function with the given arguments.
|
||||
The service starts the calls when it starts, and cancels them
|
||||
when it stops.
|
||||
|
||||
@ivar clock: Source of time. This defaults to L{None} which is
|
||||
causes L{twisted.internet.reactor} to be used.
|
||||
Feel free to set this to something else, but it probably ought to be
|
||||
set *before* calling L{startService}.
|
||||
@type clock: L{IReactorTime<twisted.internet.interfaces.IReactorTime>}
|
||||
|
||||
@ivar call: Function and arguments to call periodically.
|
||||
@type call: L{tuple} of C{(callable, args, kwargs)}
|
||||
"""
|
||||
|
||||
volatile = ['_loop', '_loopFinished']
|
||||
|
||||
def __init__(self, step, callable, *args, **kwargs):
|
||||
"""
|
||||
@param step: The number of seconds between calls.
|
||||
@type step: L{float}
|
||||
|
||||
@param callable: Function to call
|
||||
@type callable: L{callable}
|
||||
|
||||
@param args: Positional arguments to pass to function
|
||||
@param kwargs: Keyword arguments to pass to function
|
||||
"""
|
||||
self.step = step
|
||||
self.call = (callable, args, kwargs)
|
||||
self.clock = None
|
||||
|
||||
def startService(self):
|
||||
service.Service.startService(self)
|
||||
callable, args, kwargs = self.call
|
||||
# we have to make a new LoopingCall each time we're started, because
|
||||
# an active LoopingCall remains active when serialized. If
|
||||
# LoopingCall were a _VolatileDataService, we wouldn't need to do
|
||||
# this.
|
||||
self._loop = task.LoopingCall(callable, *args, **kwargs)
|
||||
self._loop.clock = _maybeGlobalReactor(self.clock)
|
||||
self._loopFinished = self._loop.start(self.step, now=True)
|
||||
self._loopFinished.addErrback(self._failed)
|
||||
|
||||
def _failed(self, why):
|
||||
# make a note that the LoopingCall is no longer looping, so we don't
|
||||
# try to shut it down a second time in stopService. I think this
|
||||
# should be in LoopingCall. -warner
|
||||
self._loop.running = False
|
||||
log.err(why)
|
||||
|
||||
def stopService(self):
|
||||
"""
|
||||
Stop the service.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is fired when the
|
||||
currently running call (if any) is finished.
|
||||
"""
|
||||
if self._loop.running:
|
||||
self._loop.stop()
|
||||
self._loopFinished.addCallback(lambda _:
|
||||
service.Service.stopService(self))
|
||||
return self._loopFinished
|
||||
|
||||
|
||||
|
||||
class CooperatorService(service.Service):
|
||||
"""
|
||||
Simple L{service.IService} which starts and stops a L{twisted.internet.task.Cooperator}.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.coop = task.Cooperator(started=False)
|
||||
|
||||
|
||||
def coiterate(self, iterator):
|
||||
return self.coop.coiterate(iterator)
|
||||
|
||||
|
||||
def startService(self):
|
||||
self.coop.start()
|
||||
|
||||
|
||||
def stopService(self):
|
||||
self.coop.stop()
|
||||
|
||||
|
||||
|
||||
class StreamServerEndpointService(service.Service, object):
|
||||
"""
|
||||
A L{StreamServerEndpointService} is an L{IService} which runs a server on a
|
||||
listening port described by an L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>}.
|
||||
|
||||
@ivar factory: A server factory which will be used to listen on the
|
||||
endpoint.
|
||||
|
||||
@ivar endpoint: An L{IStreamServerEndpoint
|
||||
<twisted.internet.interfaces.IStreamServerEndpoint>} provider
|
||||
which will be used to listen when the service starts.
|
||||
|
||||
@ivar _waitingForPort: a Deferred, if C{listen} has yet been invoked on the
|
||||
endpoint, otherwise None.
|
||||
|
||||
@ivar _raiseSynchronously: Defines error-handling behavior for the case
|
||||
where C{listen(...)} raises an exception before C{startService} or
|
||||
C{privilegedStartService} have completed.
|
||||
|
||||
@type _raiseSynchronously: C{bool}
|
||||
|
||||
@since: 10.2
|
||||
"""
|
||||
|
||||
_raiseSynchronously = None
|
||||
|
||||
def __init__(self, endpoint, factory):
|
||||
self.endpoint = endpoint
|
||||
self.factory = factory
|
||||
self._waitingForPort = None
|
||||
|
||||
|
||||
def privilegedStartService(self):
|
||||
"""
|
||||
Start listening on the endpoint.
|
||||
"""
|
||||
service.Service.privilegedStartService(self)
|
||||
self._waitingForPort = self.endpoint.listen(self.factory)
|
||||
raisedNow = []
|
||||
def handleIt(err):
|
||||
if self._raiseSynchronously:
|
||||
raisedNow.append(err)
|
||||
elif not err.check(CancelledError):
|
||||
log.err(err)
|
||||
self._waitingForPort.addErrback(handleIt)
|
||||
if raisedNow:
|
||||
raisedNow[0].raiseException()
|
||||
|
||||
|
||||
def startService(self):
|
||||
"""
|
||||
Start listening on the endpoint, unless L{privilegedStartService} got
|
||||
around to it already.
|
||||
"""
|
||||
service.Service.startService(self)
|
||||
if self._waitingForPort is None:
|
||||
self.privilegedStartService()
|
||||
|
||||
|
||||
def stopService(self):
|
||||
"""
|
||||
Stop listening on the port if it is already listening, otherwise,
|
||||
cancel the attempt to listen.
|
||||
|
||||
@return: a L{Deferred<twisted.internet.defer.Deferred>} which fires
|
||||
with C{None} when the port has stopped listening.
|
||||
"""
|
||||
self._waitingForPort.cancel()
|
||||
def stopIt(port):
|
||||
if port is not None:
|
||||
return port.stopListening()
|
||||
d = self._waitingForPort.addCallback(stopIt)
|
||||
def stop(passthrough):
|
||||
self.running = False
|
||||
return passthrough
|
||||
d.addBoth(stop)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
__all__ = (['TimerService', 'CooperatorService', 'MulticastServer',
|
||||
'StreamServerEndpointService', 'UDPServer'] +
|
||||
[tran+side
|
||||
for tran in 'TCP UNIX SSL UNIXDatagram'.split()
|
||||
for side in 'Server Client'.split()])
|
@ -0,0 +1,84 @@
|
||||
# -*- test-case-name: twisted.test.test_application -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Plugin-based system for enumerating available reactors and installing one of
|
||||
them.
|
||||
"""
|
||||
|
||||
from zope.interface import Interface, Attribute, implementer
|
||||
|
||||
from twisted.plugin import IPlugin, getPlugins
|
||||
from twisted.python.reflect import namedAny
|
||||
|
||||
|
||||
class IReactorInstaller(Interface):
|
||||
"""
|
||||
Definition of a reactor which can probably be installed.
|
||||
"""
|
||||
shortName = Attribute("""
|
||||
A brief string giving the user-facing name of this reactor.
|
||||
""")
|
||||
|
||||
description = Attribute("""
|
||||
A longer string giving a user-facing description of this reactor.
|
||||
""")
|
||||
|
||||
def install():
|
||||
"""
|
||||
Install this reactor.
|
||||
"""
|
||||
|
||||
# TODO - A method which provides a best-guess as to whether this reactor
|
||||
# can actually be used in the execution environment.
|
||||
|
||||
|
||||
|
||||
class NoSuchReactor(KeyError):
|
||||
"""
|
||||
Raised when an attempt is made to install a reactor which cannot be found.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(IPlugin, IReactorInstaller)
|
||||
class Reactor(object):
|
||||
"""
|
||||
@ivar moduleName: The fully-qualified Python name of the module of which
|
||||
the install callable is an attribute.
|
||||
"""
|
||||
def __init__(self, shortName, moduleName, description):
|
||||
self.shortName = shortName
|
||||
self.moduleName = moduleName
|
||||
self.description = description
|
||||
|
||||
|
||||
def install(self):
|
||||
namedAny(self.moduleName).install()
|
||||
|
||||
|
||||
|
||||
def getReactorTypes():
|
||||
"""
|
||||
Return an iterator of L{IReactorInstaller} plugins.
|
||||
"""
|
||||
return getPlugins(IReactorInstaller)
|
||||
|
||||
|
||||
|
||||
def installReactor(shortName):
|
||||
"""
|
||||
Install the reactor with the given C{shortName} attribute.
|
||||
|
||||
@raise NoSuchReactor: If no reactor is found with a matching C{shortName}.
|
||||
|
||||
@raise: anything that the specified reactor can raise when installed.
|
||||
"""
|
||||
for installer in getReactorTypes():
|
||||
if installer.shortName == shortName:
|
||||
installer.install()
|
||||
from twisted.internet import reactor
|
||||
return reactor
|
||||
raise NoSuchReactor(shortName)
|
||||
|
@ -0,0 +1,411 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Service architecture for Twisted.
|
||||
|
||||
Services are arranged in a hierarchy. At the leafs of the hierarchy,
|
||||
the services which actually interact with the outside world are started.
|
||||
Services can be named or anonymous -- usually, they will be named if
|
||||
there is need to access them through the hierarchy (from a parent or
|
||||
a sibling).
|
||||
|
||||
Maintainer: Moshe Zadka
|
||||
"""
|
||||
|
||||
from zope.interface import implementer, Interface, Attribute
|
||||
|
||||
from twisted.python.reflect import namedAny
|
||||
from twisted.python import components
|
||||
from twisted.internet import defer
|
||||
from twisted.persisted import sob
|
||||
from twisted.plugin import IPlugin
|
||||
|
||||
|
||||
class IServiceMaker(Interface):
|
||||
"""
|
||||
An object which can be used to construct services in a flexible
|
||||
way.
|
||||
|
||||
This interface should most often be implemented along with
|
||||
L{twisted.plugin.IPlugin}, and will most often be used by the
|
||||
'twistd' command.
|
||||
"""
|
||||
tapname = Attribute(
|
||||
"A short string naming this Twisted plugin, for example 'web' or "
|
||||
"'pencil'. This name will be used as the subcommand of 'twistd'.")
|
||||
|
||||
description = Attribute(
|
||||
"A brief summary of the features provided by this "
|
||||
"Twisted application plugin.")
|
||||
|
||||
options = Attribute(
|
||||
"A C{twisted.python.usage.Options} subclass defining the "
|
||||
"configuration options for this application.")
|
||||
|
||||
|
||||
def makeService(options):
|
||||
"""
|
||||
Create and return an object providing
|
||||
L{twisted.application.service.IService}.
|
||||
|
||||
@param options: A mapping (typically a C{dict} or
|
||||
L{twisted.python.usage.Options} instance) of configuration
|
||||
options to desired configuration values.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(IPlugin, IServiceMaker)
|
||||
class ServiceMaker(object):
|
||||
"""
|
||||
Utility class to simplify the definition of L{IServiceMaker} plugins.
|
||||
"""
|
||||
def __init__(self, name, module, description, tapname):
|
||||
self.name = name
|
||||
self.module = module
|
||||
self.description = description
|
||||
self.tapname = tapname
|
||||
|
||||
|
||||
def options():
|
||||
def get(self):
|
||||
return namedAny(self.module).Options
|
||||
return get,
|
||||
options = property(*options())
|
||||
|
||||
|
||||
def makeService():
|
||||
def get(self):
|
||||
return namedAny(self.module).makeService
|
||||
return get,
|
||||
makeService = property(*makeService())
|
||||
|
||||
|
||||
|
||||
class IService(Interface):
|
||||
"""
|
||||
A service.
|
||||
|
||||
Run start-up and shut-down code at the appropriate times.
|
||||
|
||||
@type name: C{string}
|
||||
@ivar name: The name of the service (or None)
|
||||
@type running: C{boolean}
|
||||
@ivar running: Whether the service is running.
|
||||
"""
|
||||
|
||||
def setName(name):
|
||||
"""
|
||||
Set the name of the service.
|
||||
|
||||
@type name: C{str}
|
||||
@raise RuntimeError: Raised if the service already has a parent.
|
||||
"""
|
||||
|
||||
def setServiceParent(parent):
|
||||
"""
|
||||
Set the parent of the service. This method is responsible for setting
|
||||
the C{parent} attribute on this service (the child service).
|
||||
|
||||
@type parent: L{IServiceCollection}
|
||||
@raise RuntimeError: Raised if the service already has a parent
|
||||
or if the service has a name and the parent already has a child
|
||||
by that name.
|
||||
"""
|
||||
|
||||
def disownServiceParent():
|
||||
"""
|
||||
Use this API to remove an L{IService} from an L{IServiceCollection}.
|
||||
|
||||
This method is used symmetrically with L{setServiceParent} in that it
|
||||
sets the C{parent} attribute on the child.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, C{None}).
|
||||
"""
|
||||
|
||||
def startService():
|
||||
"""
|
||||
Start the service.
|
||||
"""
|
||||
|
||||
def stopService():
|
||||
"""
|
||||
Stop the service.
|
||||
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, C{None}).
|
||||
"""
|
||||
|
||||
def privilegedStartService():
|
||||
"""
|
||||
Do preparation work for starting the service.
|
||||
|
||||
Here things which should be done before changing directory,
|
||||
root or shedding privileges are done.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(IService)
|
||||
class Service:
|
||||
"""
|
||||
Base class for services.
|
||||
|
||||
Most services should inherit from this class. It handles the
|
||||
book-keeping responsibilities of starting and stopping, as well
|
||||
as not serializing this book-keeping information.
|
||||
"""
|
||||
|
||||
running = 0
|
||||
name = None
|
||||
parent = None
|
||||
|
||||
def __getstate__(self):
|
||||
dict = self.__dict__.copy()
|
||||
if "running" in dict:
|
||||
del dict['running']
|
||||
return dict
|
||||
|
||||
def setName(self, name):
|
||||
if self.parent is not None:
|
||||
raise RuntimeError("cannot change name when parent exists")
|
||||
self.name = name
|
||||
|
||||
def setServiceParent(self, parent):
|
||||
if self.parent is not None:
|
||||
self.disownServiceParent()
|
||||
parent = IServiceCollection(parent, parent)
|
||||
self.parent = parent
|
||||
self.parent.addService(self)
|
||||
|
||||
def disownServiceParent(self):
|
||||
d = self.parent.removeService(self)
|
||||
self.parent = None
|
||||
return d
|
||||
|
||||
def privilegedStartService(self):
|
||||
pass
|
||||
|
||||
def startService(self):
|
||||
self.running = 1
|
||||
|
||||
def stopService(self):
|
||||
self.running = 0
|
||||
|
||||
|
||||
|
||||
class IServiceCollection(Interface):
|
||||
"""
|
||||
Collection of services.
|
||||
|
||||
Contain several services, and manage their start-up/shut-down.
|
||||
Services can be accessed by name if they have a name, and it
|
||||
is always possible to iterate over them.
|
||||
"""
|
||||
|
||||
def getServiceNamed(name):
|
||||
"""
|
||||
Get the child service with a given name.
|
||||
|
||||
@type name: C{str}
|
||||
@rtype: L{IService}
|
||||
@raise KeyError: Raised if the service has no child with the
|
||||
given name.
|
||||
"""
|
||||
|
||||
def __iter__():
|
||||
"""
|
||||
Get an iterator over all child services.
|
||||
"""
|
||||
|
||||
def addService(service):
|
||||
"""
|
||||
Add a child service.
|
||||
|
||||
Only implementations of L{IService.setServiceParent} should use this
|
||||
method.
|
||||
|
||||
@type service: L{IService}
|
||||
@raise RuntimeError: Raised if the service has a child with
|
||||
the given name.
|
||||
"""
|
||||
|
||||
def removeService(service):
|
||||
"""
|
||||
Remove a child service.
|
||||
|
||||
Only implementations of L{IService.disownServiceParent} should
|
||||
use this method.
|
||||
|
||||
@type service: L{IService}
|
||||
@raise ValueError: Raised if the given service is not a child.
|
||||
@rtype: L{Deferred<defer.Deferred>}
|
||||
@return: a L{Deferred<defer.Deferred>} which is triggered when the
|
||||
service has finished shutting down. If shutting down is immediate,
|
||||
a value can be returned (usually, C{None}).
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(IServiceCollection)
|
||||
class MultiService(Service):
|
||||
"""
|
||||
Straightforward Service Container.
|
||||
|
||||
Hold a collection of services, and manage them in a simplistic
|
||||
way. No service will wait for another, but this object itself
|
||||
will not finish shutting down until all of its child services
|
||||
will finish.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.services = []
|
||||
self.namedServices = {}
|
||||
self.parent = None
|
||||
|
||||
def privilegedStartService(self):
|
||||
Service.privilegedStartService(self)
|
||||
for service in self:
|
||||
service.privilegedStartService()
|
||||
|
||||
def startService(self):
|
||||
Service.startService(self)
|
||||
for service in self:
|
||||
service.startService()
|
||||
|
||||
def stopService(self):
|
||||
Service.stopService(self)
|
||||
l = []
|
||||
services = list(self)
|
||||
services.reverse()
|
||||
for service in services:
|
||||
l.append(defer.maybeDeferred(service.stopService))
|
||||
return defer.DeferredList(l)
|
||||
|
||||
def getServiceNamed(self, name):
|
||||
return self.namedServices[name]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.services)
|
||||
|
||||
def addService(self, service):
|
||||
if service.name is not None:
|
||||
if service.name in self.namedServices:
|
||||
raise RuntimeError("cannot have two services with same name"
|
||||
" '%s'" % service.name)
|
||||
self.namedServices[service.name] = service
|
||||
self.services.append(service)
|
||||
if self.running:
|
||||
# It may be too late for that, but we will do our best
|
||||
service.privilegedStartService()
|
||||
service.startService()
|
||||
|
||||
def removeService(self, service):
|
||||
if service.name:
|
||||
del self.namedServices[service.name]
|
||||
self.services.remove(service)
|
||||
if self.running:
|
||||
# Returning this so as not to lose information from the
|
||||
# MultiService.stopService deferred.
|
||||
return service.stopService()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class IProcess(Interface):
|
||||
"""
|
||||
Process running parameters.
|
||||
|
||||
Represents parameters for how processes should be run.
|
||||
"""
|
||||
processName = Attribute(
|
||||
"""
|
||||
A C{str} giving the name the process should have in ps (or C{None}
|
||||
to leave the name alone).
|
||||
""")
|
||||
|
||||
uid = Attribute(
|
||||
"""
|
||||
An C{int} giving the user id as which the process should run (or
|
||||
C{None} to leave the UID alone).
|
||||
""")
|
||||
|
||||
gid = Attribute(
|
||||
"""
|
||||
An C{int} giving the group id as which the process should run (or
|
||||
C{None} to leave the GID alone).
|
||||
""")
|
||||
|
||||
|
||||
|
||||
@implementer(IProcess)
|
||||
class Process:
|
||||
"""
|
||||
Process running parameters.
|
||||
|
||||
Sets up uid/gid in the constructor, and has a default
|
||||
of C{None} as C{processName}.
|
||||
"""
|
||||
processName = None
|
||||
|
||||
def __init__(self, uid=None, gid=None):
|
||||
"""
|
||||
Set uid and gid.
|
||||
|
||||
@param uid: The user ID as whom to execute the process. If
|
||||
this is C{None}, no attempt will be made to change the UID.
|
||||
|
||||
@param gid: The group ID as whom to execute the process. If
|
||||
this is C{None}, no attempt will be made to change the GID.
|
||||
"""
|
||||
self.uid = uid
|
||||
self.gid = gid
|
||||
|
||||
|
||||
def Application(name, uid=None, gid=None):
|
||||
"""
|
||||
Return a compound class.
|
||||
|
||||
Return an object supporting the L{IService}, L{IServiceCollection},
|
||||
L{IProcess} and L{sob.IPersistable} interfaces, with the given
|
||||
parameters. Always access the return value by explicit casting to
|
||||
one of the interfaces.
|
||||
"""
|
||||
ret = components.Componentized()
|
||||
for comp in (MultiService(), sob.Persistent(ret, name), Process(uid, gid)):
|
||||
ret.addComponent(comp, ignoreClass=1)
|
||||
IService(ret).setName(name)
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
def loadApplication(filename, kind, passphrase=None):
|
||||
"""
|
||||
Load Application from a given file.
|
||||
|
||||
The serialization format it was saved in should be given as
|
||||
C{kind}, and is one of C{pickle}, C{source}, C{xml} or C{python}. If
|
||||
C{passphrase} is given, the application was encrypted with the
|
||||
given passphrase.
|
||||
|
||||
@type filename: C{str}
|
||||
@type kind: C{str}
|
||||
@type passphrase: C{str}
|
||||
"""
|
||||
if kind == 'python':
|
||||
application = sob.loadValueFromFile(filename, 'application', passphrase)
|
||||
else:
|
||||
application = sob.load(filename, kind, passphrase)
|
||||
return application
|
||||
|
||||
|
||||
__all__ = ['IServiceMaker', 'IService', 'Service',
|
||||
'IServiceCollection', 'MultiService',
|
||||
'IProcess', 'Process', 'Application', 'loadApplication']
|
@ -0,0 +1,103 @@
|
||||
# -*- test-case-name: twisted.test.test_strports -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Construct listening port services from a simple string description.
|
||||
|
||||
@see: L{twisted.internet.endpoints.serverFromString}
|
||||
@see: L{twisted.internet.endpoints.clientFromString}
|
||||
"""
|
||||
|
||||
import warnings
|
||||
|
||||
from twisted.internet import endpoints
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
from twisted.python.versions import Version
|
||||
from twisted.application.internet import StreamServerEndpointService
|
||||
|
||||
|
||||
|
||||
def parse(description, factory, default='tcp'):
|
||||
"""
|
||||
This function is deprecated as of Twisted 10.2.
|
||||
|
||||
@see: L{twisted.internet.endpoints.server}
|
||||
"""
|
||||
return endpoints._parseServer(description, factory, default)
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 10, 2, 0),
|
||||
"in favor of twisted.internet.endpoints.serverFromString",
|
||||
__name__, "parse")
|
||||
|
||||
|
||||
|
||||
_DEFAULT = object()
|
||||
|
||||
def service(description, factory, default=_DEFAULT, reactor=None):
|
||||
"""
|
||||
Return the service corresponding to a description.
|
||||
|
||||
@param description: The description of the listening port, in the syntax
|
||||
described by L{twisted.internet.endpoints.server}.
|
||||
|
||||
@type description: C{str}
|
||||
|
||||
@param factory: The protocol factory which will build protocols for
|
||||
connections to this service.
|
||||
|
||||
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
|
||||
@type default: C{str} or C{None}
|
||||
|
||||
@param default: Do not use this parameter. It has been deprecated since
|
||||
Twisted 10.2.0.
|
||||
|
||||
@rtype: C{twisted.application.service.IService}
|
||||
|
||||
@return: the service corresponding to a description of a reliable
|
||||
stream server.
|
||||
|
||||
@see: L{twisted.internet.endpoints.serverFromString}
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
if default is _DEFAULT:
|
||||
default = None
|
||||
else:
|
||||
message = "The 'default' parameter was deprecated in Twisted 10.2.0."
|
||||
if default is not None:
|
||||
message += (
|
||||
" Use qualified endpoint descriptions; for example, "
|
||||
"'tcp:%s'." % (description,))
|
||||
warnings.warn(
|
||||
message=message, category=DeprecationWarning, stacklevel=2)
|
||||
svc = StreamServerEndpointService(
|
||||
endpoints._serverFromStringLegacy(reactor, description, default),
|
||||
factory)
|
||||
svc._raiseSynchronously = True
|
||||
return svc
|
||||
|
||||
|
||||
|
||||
def listen(description, factory, default=None):
|
||||
"""Listen on a port corresponding to a description
|
||||
|
||||
@type description: C{str}
|
||||
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
@type default: C{str} or C{None}
|
||||
@rtype: C{twisted.internet.interfaces.IListeningPort}
|
||||
@return: the port corresponding to a description of a reliable
|
||||
virtual circuit server.
|
||||
|
||||
See the documentation of the C{parse} function for description
|
||||
of the semantics of the arguments.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
name, args, kw = parse(description, factory, default)
|
||||
return getattr(reactor, 'listen'+name)(*args, **kw)
|
||||
|
||||
|
||||
|
||||
__all__ = ['parse', 'service', 'listen']
|
@ -0,0 +1,6 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.internet.application}.
|
||||
"""
|
@ -0,0 +1,403 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for (new code in) L{twisted.application.internet}.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
|
||||
from zope.interface import implements
|
||||
from zope.interface.verify import verifyClass
|
||||
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.application import internet
|
||||
from twisted.application.internet import (
|
||||
StreamServerEndpointService, TimerService)
|
||||
from twisted.internet.interfaces import IStreamServerEndpoint, IListeningPort
|
||||
from twisted.internet.defer import Deferred, CancelledError
|
||||
from twisted.internet import task
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
|
||||
def fakeTargetFunction():
|
||||
"""
|
||||
A fake target function for testing TimerService which does nothing.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class FakeServer(object):
|
||||
"""
|
||||
In-memory implementation of L{IStreamServerEndpoint}.
|
||||
|
||||
@ivar result: The L{Deferred} resulting from the call to C{listen}, after
|
||||
C{listen} has been called.
|
||||
|
||||
@ivar factory: The factory passed to C{listen}.
|
||||
|
||||
@ivar cancelException: The exception to errback C{self.result} when it is
|
||||
cancelled.
|
||||
|
||||
@ivar port: The L{IListeningPort} which C{listen}'s L{Deferred} will fire
|
||||
with.
|
||||
|
||||
@ivar listenAttempts: The number of times C{listen} has been invoked.
|
||||
|
||||
@ivar failImmediately: If set, the exception to fail the L{Deferred}
|
||||
returned from C{listen} before it is returned.
|
||||
"""
|
||||
|
||||
implements(IStreamServerEndpoint)
|
||||
|
||||
result = None
|
||||
factory = None
|
||||
failImmediately = None
|
||||
cancelException = CancelledError()
|
||||
listenAttempts = 0
|
||||
|
||||
def __init__(self):
|
||||
self.port = FakePort()
|
||||
|
||||
|
||||
def listen(self, factory):
|
||||
"""
|
||||
Return a Deferred and store it for future use. (Implementation of
|
||||
L{IStreamServerEndpoint}).
|
||||
"""
|
||||
self.listenAttempts += 1
|
||||
self.factory = factory
|
||||
self.result = Deferred(
|
||||
canceller=lambda d: d.errback(self.cancelException))
|
||||
if self.failImmediately is not None:
|
||||
self.result.errback(self.failImmediately)
|
||||
return self.result
|
||||
|
||||
|
||||
def startedListening(self):
|
||||
"""
|
||||
Test code should invoke this method after causing C{listen} to be
|
||||
invoked in order to fire the L{Deferred} previously returned from
|
||||
C{listen}.
|
||||
"""
|
||||
self.result.callback(self.port)
|
||||
|
||||
|
||||
def stoppedListening(self):
|
||||
"""
|
||||
Test code should invoke this method after causing C{stopListening} to
|
||||
be invoked on the port fired from the L{Deferred} returned from
|
||||
C{listen} in order to cause the L{Deferred} returned from
|
||||
C{stopListening} to fire.
|
||||
"""
|
||||
self.port.deferred.callback(None)
|
||||
|
||||
verifyClass(IStreamServerEndpoint, FakeServer)
|
||||
|
||||
|
||||
|
||||
class FakePort(object):
|
||||
"""
|
||||
Fake L{IListeningPort} implementation.
|
||||
|
||||
@ivar deferred: The L{Deferred} returned by C{stopListening}.
|
||||
"""
|
||||
|
||||
implements(IListeningPort)
|
||||
|
||||
deferred = None
|
||||
|
||||
def stopListening(self):
|
||||
self.deferred = Deferred()
|
||||
return self.deferred
|
||||
|
||||
verifyClass(IStreamServerEndpoint, FakeServer)
|
||||
|
||||
|
||||
|
||||
class EndpointServiceTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.application.internet}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Construct a stub server, a stub factory, and a
|
||||
L{StreamServerEndpointService} to test.
|
||||
"""
|
||||
self.fakeServer = FakeServer()
|
||||
self.factory = Factory()
|
||||
self.svc = StreamServerEndpointService(self.fakeServer, self.factory)
|
||||
|
||||
|
||||
def test_privilegedStartService(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.privilegedStartService} calls its
|
||||
endpoint's C{listen} method with its factory.
|
||||
"""
|
||||
self.svc.privilegedStartService()
|
||||
self.assertIdentical(self.factory, self.fakeServer.factory)
|
||||
|
||||
|
||||
def test_synchronousRaiseRaisesSynchronously(self, thunk=None):
|
||||
"""
|
||||
L{StreamServerEndpointService.startService} should raise synchronously
|
||||
if the L{Deferred} returned by its wrapped
|
||||
L{IStreamServerEndpoint.listen} has already fired with an errback and
|
||||
the L{StreamServerEndpointService}'s C{_raiseSynchronously} flag has
|
||||
been set. This feature is necessary to preserve compatibility with old
|
||||
behavior of L{twisted.internet.strports.service}, which is to return a
|
||||
service which synchronously raises an exception from C{startService}
|
||||
(so that, among other things, twistd will not start running). However,
|
||||
since L{IStreamServerEndpoint.listen} may fail asynchronously, it is
|
||||
a bad idea to rely on this behavior.
|
||||
"""
|
||||
self.fakeServer.failImmediately = ZeroDivisionError()
|
||||
self.svc._raiseSynchronously = True
|
||||
self.assertRaises(ZeroDivisionError, thunk or self.svc.startService)
|
||||
|
||||
|
||||
def test_synchronousRaisePrivileged(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.privilegedStartService} should behave the
|
||||
same as C{startService} with respect to
|
||||
L{EndpointServiceTests.test_synchronousRaiseRaisesSynchronously}.
|
||||
"""
|
||||
self.test_synchronousRaiseRaisesSynchronously(
|
||||
self.svc.privilegedStartService)
|
||||
|
||||
|
||||
def test_failReportsError(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.startService} and
|
||||
L{StreamServerEndpointService.privilegedStartService} should both log
|
||||
an exception when the L{Deferred} returned from their wrapped
|
||||
L{IStreamServerEndpoint.listen} fails.
|
||||
"""
|
||||
self.svc.startService()
|
||||
self.fakeServer.result.errback(ZeroDivisionError())
|
||||
logged = self.flushLoggedErrors(ZeroDivisionError)
|
||||
self.assertEqual(len(logged), 1)
|
||||
|
||||
|
||||
def test_synchronousFailReportsError(self):
|
||||
"""
|
||||
Without the C{_raiseSynchronously} compatibility flag, failing
|
||||
immediately has the same behavior as failing later; it logs the error.
|
||||
"""
|
||||
self.fakeServer.failImmediately = ZeroDivisionError()
|
||||
self.svc.startService()
|
||||
logged = self.flushLoggedErrors(ZeroDivisionError)
|
||||
self.assertEqual(len(logged), 1)
|
||||
|
||||
|
||||
def test_startServiceUnstarted(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.startService} sets the C{running} flag,
|
||||
and calls its endpoint's C{listen} method with its factory, if it
|
||||
has not yet been started.
|
||||
"""
|
||||
self.svc.startService()
|
||||
self.assertIdentical(self.factory, self.fakeServer.factory)
|
||||
self.assertEqual(self.svc.running, True)
|
||||
|
||||
|
||||
def test_startServiceStarted(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.startService} sets the C{running} flag,
|
||||
but nothing else, if the service has already been started.
|
||||
"""
|
||||
self.test_privilegedStartService()
|
||||
self.svc.startService()
|
||||
self.assertEqual(self.fakeServer.listenAttempts, 1)
|
||||
self.assertEqual(self.svc.running, True)
|
||||
|
||||
|
||||
def test_stopService(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.stopService} calls C{stopListening} on
|
||||
the L{IListeningPort} returned from its endpoint, returns the
|
||||
C{Deferred} from stopService, and sets C{running} to C{False}.
|
||||
"""
|
||||
self.svc.privilegedStartService()
|
||||
self.fakeServer.startedListening()
|
||||
# Ensure running gets set to true
|
||||
self.svc.startService()
|
||||
result = self.svc.stopService()
|
||||
l = []
|
||||
result.addCallback(l.append)
|
||||
self.assertEqual(len(l), 0)
|
||||
self.fakeServer.stoppedListening()
|
||||
self.assertEqual(len(l), 1)
|
||||
self.assertFalse(self.svc.running)
|
||||
|
||||
|
||||
def test_stopServiceBeforeStartFinished(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.stopService} cancels the L{Deferred}
|
||||
returned by C{listen} if it has not yet fired. No error will be logged
|
||||
about the cancellation of the listen attempt.
|
||||
"""
|
||||
self.svc.privilegedStartService()
|
||||
result = self.svc.stopService()
|
||||
l = []
|
||||
result.addBoth(l.append)
|
||||
self.assertEqual(l, [None])
|
||||
self.assertEqual(self.flushLoggedErrors(CancelledError), [])
|
||||
|
||||
|
||||
def test_stopServiceCancelStartError(self):
|
||||
"""
|
||||
L{StreamServerEndpointService.stopService} cancels the L{Deferred}
|
||||
returned by C{listen} if it has not fired yet. An error will be logged
|
||||
if the resulting exception is not L{CancelledError}.
|
||||
"""
|
||||
self.fakeServer.cancelException = ZeroDivisionError()
|
||||
self.svc.privilegedStartService()
|
||||
result = self.svc.stopService()
|
||||
l = []
|
||||
result.addCallback(l.append)
|
||||
self.assertEqual(l, [None])
|
||||
stoppingErrors = self.flushLoggedErrors(ZeroDivisionError)
|
||||
self.assertEqual(len(stoppingErrors), 1)
|
||||
|
||||
|
||||
|
||||
class TimerServiceTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.application.internet.TimerService}.
|
||||
|
||||
@type timer: L{TimerService}
|
||||
@ivar timer: service to test
|
||||
|
||||
@type clock: L{task.Clock}
|
||||
@ivar clock: source of time
|
||||
|
||||
@type deferred: L{Deferred}
|
||||
@ivar deferred: deferred returned by L{TimerServiceTests.call}.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.timer = TimerService(2, self.call)
|
||||
self.clock = self.timer.clock = task.Clock()
|
||||
self.deferred = Deferred()
|
||||
|
||||
|
||||
def call(self):
|
||||
"""
|
||||
Function called by L{TimerService} being tested.
|
||||
|
||||
@returns: C{self.deferred}
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self.deferred
|
||||
|
||||
|
||||
def test_startService(self):
|
||||
"""
|
||||
When L{TimerService.startService} is called, it marks itself
|
||||
as running, creates a L{task.LoopingCall} and starts it.
|
||||
"""
|
||||
self.timer.startService()
|
||||
self.assertTrue(self.timer.running, "Service is started")
|
||||
self.assertIsInstance(self.timer._loop, task.LoopingCall)
|
||||
self.assertIdentical(self.clock, self.timer._loop.clock)
|
||||
self.assertTrue(self.timer._loop.running, "LoopingCall is started")
|
||||
|
||||
|
||||
def test_startServiceRunsCallImmediately(self):
|
||||
"""
|
||||
When L{TimerService.startService} is called, it calls the function
|
||||
immediately.
|
||||
"""
|
||||
result = []
|
||||
self.timer.call = (result.append, (None,), {})
|
||||
self.timer.startService()
|
||||
self.assertEqual([None], result)
|
||||
|
||||
|
||||
def test_startServiceUsesGlobalReactor(self):
|
||||
"""
|
||||
L{TimerService.startService} uses L{internet._maybeGlobalReactor} to
|
||||
choose the reactor to pass to L{task.LoopingCall}
|
||||
uses the global reactor.
|
||||
"""
|
||||
otherClock = task.Clock()
|
||||
def getOtherClock(maybeReactor):
|
||||
return otherClock
|
||||
self.patch(internet, "_maybeGlobalReactor", getOtherClock)
|
||||
self.timer.startService()
|
||||
self.assertIdentical(otherClock, self.timer._loop.clock)
|
||||
|
||||
|
||||
def test_stopServiceWaits(self):
|
||||
"""
|
||||
When L{TimerService.stopService} is called while a call is in progress.
|
||||
the L{Deferred} returned doesn't fire until after the call finishes.
|
||||
"""
|
||||
self.timer.startService()
|
||||
d = self.timer.stopService()
|
||||
self.assertNoResult(d)
|
||||
self.assertEqual(True, self.timer.running)
|
||||
self.deferred.callback(object())
|
||||
self.assertIdentical(self.successResultOf(d), None)
|
||||
|
||||
|
||||
def test_stopServiceImmediately(self):
|
||||
"""
|
||||
When L{TimerService.stopService} is called while a call isn't in progress.
|
||||
the L{Deferred} returned has already been fired.
|
||||
"""
|
||||
self.timer.startService()
|
||||
self.deferred.callback(object())
|
||||
d = self.timer.stopService()
|
||||
self.assertIdentical(self.successResultOf(d), None)
|
||||
|
||||
|
||||
def test_failedCallLogsError(self):
|
||||
"""
|
||||
When function passed to L{TimerService} returns a deferred that errbacks,
|
||||
the exception is logged, and L{TimerService.stopService} doesn't raise an error.
|
||||
"""
|
||||
self.timer.startService()
|
||||
self.deferred.errback(Failure(ZeroDivisionError()))
|
||||
errors = self.flushLoggedErrors(ZeroDivisionError)
|
||||
self.assertEqual(1, len(errors))
|
||||
d = self.timer.stopService()
|
||||
self.assertIdentical(self.successResultOf(d), None)
|
||||
|
||||
|
||||
def test_pickleTimerServiceNotPickleLoop(self):
|
||||
"""
|
||||
When pickling L{internet.TimerService}, it won't pickle
|
||||
L{internet.TimerService._loop}.
|
||||
"""
|
||||
# We need a pickleable callable to test pickling TimerService. So we
|
||||
# can't use self.timer
|
||||
timer = TimerService(1, fakeTargetFunction)
|
||||
timer.startService()
|
||||
dumpedTimer = pickle.dumps(timer)
|
||||
timer.stopService()
|
||||
loadedTimer = pickle.loads(dumpedTimer)
|
||||
nothing = object()
|
||||
value = getattr(loadedTimer, "_loop", nothing)
|
||||
self.assertIdentical(nothing, value)
|
||||
|
||||
|
||||
def test_pickleTimerServiceNotPickleLoopFinished(self):
|
||||
"""
|
||||
When pickling L{internet.TimerService}, it won't pickle
|
||||
L{internet.TimerService._loopFinished}.
|
||||
"""
|
||||
# We need a pickleable callable to test pickling TimerService. So we
|
||||
# can't use self.timer
|
||||
timer = TimerService(1, fakeTargetFunction)
|
||||
timer.startService()
|
||||
dumpedTimer = pickle.dumps(timer)
|
||||
timer.stopService()
|
||||
loadedTimer = pickle.loads(dumpedTimer)
|
||||
nothing = object()
|
||||
value = getattr(loadedTimer, "_loopFinished", nothing)
|
||||
self.assertIdentical(nothing, value)
|
@ -0,0 +1,10 @@
|
||||
# -*- test-case-name: twisted.conch.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
|
||||
"""
|
||||
|
||||
from twisted.conch._version import version
|
||||
__version__ = version.short()
|
@ -0,0 +1,11 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
# This is an auto-generated file. Do not edit it.
|
||||
|
||||
"""
|
||||
Provides Twisted version information.
|
||||
"""
|
||||
|
||||
from twisted.python import versions
|
||||
version = versions.Version('twisted.conch', 15, 2, 1)
|
@ -0,0 +1,39 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.interfaces import IConchUser
|
||||
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
@implementer(IConchUser)
|
||||
class ConchUser:
|
||||
def __init__(self):
|
||||
self.channelLookup = {}
|
||||
self.subsystemLookup = {}
|
||||
|
||||
def lookupChannel(self, channelType, windowSize, maxPacket, data):
|
||||
klass = self.channelLookup.get(channelType, None)
|
||||
if not klass:
|
||||
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
|
||||
else:
|
||||
return klass(remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket,
|
||||
data=data, avatar=self)
|
||||
|
||||
def lookupSubsystem(self, subsystem, data):
|
||||
log.msg(repr(self.subsystemLookup))
|
||||
klass = self.subsystemLookup.get(subsystem, None)
|
||||
if not klass:
|
||||
return False
|
||||
return klass(data, avatar=self)
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
# XXX should this use method dispatch?
|
||||
requestType = requestType.replace('-', '_')
|
||||
f = getattr(self, "global_%s" % requestType, None)
|
||||
if not f:
|
||||
return 0
|
||||
return f(data)
|
@ -0,0 +1,570 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_checkers -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
|
||||
"""
|
||||
|
||||
import base64, binascii, errno
|
||||
try:
|
||||
import pwd
|
||||
except ImportError:
|
||||
pwd = None
|
||||
else:
|
||||
import crypt
|
||||
|
||||
try:
|
||||
# Python 2.5 got spwd to interface with shadow passwords
|
||||
import spwd
|
||||
except ImportError:
|
||||
spwd = None
|
||||
try:
|
||||
import shadow
|
||||
except ImportError:
|
||||
shadow = None
|
||||
else:
|
||||
shadow = None
|
||||
|
||||
try:
|
||||
from twisted.cred import pamauth
|
||||
except ImportError:
|
||||
pamauth = None
|
||||
|
||||
from zope.interface import providedBy, implementer, Interface
|
||||
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.cred.checkers import ICredentialsChecker
|
||||
from twisted.cred.credentials import IUsernamePassword, ISSHPrivateKey
|
||||
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
|
||||
from twisted.internet import defer
|
||||
from twisted.python import failure, reflect, log
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
from twisted.python.util import runAsEffectiveUser
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.versions import Version
|
||||
|
||||
|
||||
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return crypt.crypt(pw, crypted) == crypted
|
||||
|
||||
|
||||
|
||||
def _pwdGetByName(username):
|
||||
"""
|
||||
Look up a user in the /etc/passwd database using the pwd module. If the
|
||||
pwd module is not available, return None.
|
||||
|
||||
@param username: the username of the user to return the passwd database
|
||||
information for.
|
||||
"""
|
||||
if pwd is None:
|
||||
return None
|
||||
return pwd.getpwnam(username)
|
||||
|
||||
|
||||
|
||||
def _shadowGetByName(username):
|
||||
"""
|
||||
Look up a user in the /etc/shadow database using the spwd or shadow
|
||||
modules. If neither module is available, return None.
|
||||
|
||||
@param username: the username of the user to return the shadow database
|
||||
information for.
|
||||
"""
|
||||
if spwd is not None:
|
||||
f = spwd.getspnam
|
||||
elif shadow is not None:
|
||||
f = shadow.getspnam
|
||||
else:
|
||||
return None
|
||||
return runAsEffectiveUser(0, 0, f, username)
|
||||
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class UNIXPasswordDatabase:
|
||||
"""
|
||||
A checker which validates users out of the UNIX password databases, or
|
||||
databases of a compatible format.
|
||||
|
||||
@ivar _getByNameFunctions: a C{list} of functions which are called in order
|
||||
to valid a user. The default value is such that the /etc/passwd
|
||||
database will be tried first, followed by the /etc/shadow database.
|
||||
"""
|
||||
credentialInterfaces = IUsernamePassword,
|
||||
|
||||
def __init__(self, getByNameFunctions=None):
|
||||
if getByNameFunctions is None:
|
||||
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
|
||||
self._getByNameFunctions = getByNameFunctions
|
||||
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
for func in self._getByNameFunctions:
|
||||
try:
|
||||
pwnam = func(credentials.username)
|
||||
except KeyError:
|
||||
return defer.fail(UnauthorizedLogin("invalid username"))
|
||||
else:
|
||||
if pwnam is not None:
|
||||
crypted = pwnam[1]
|
||||
if crypted == '':
|
||||
continue
|
||||
if verifyCryptedPassword(crypted, credentials.password):
|
||||
return defer.succeed(credentials.username)
|
||||
# fallback
|
||||
return defer.fail(UnauthorizedLogin("unable to verify password"))
|
||||
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHPublicKeyDatabase:
|
||||
"""
|
||||
Checker that authenticates SSH public keys, based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in user .ssh/ directories.
|
||||
"""
|
||||
credentialInterfaces = (ISSHPrivateKey,)
|
||||
|
||||
_userdb = pwd
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
d = defer.maybeDeferred(self.checkKey, credentials)
|
||||
d.addCallback(self._cbRequestAvatarId, credentials)
|
||||
d.addErrback(self._ebRequestAvatarId)
|
||||
return d
|
||||
|
||||
def _cbRequestAvatarId(self, validKey, credentials):
|
||||
"""
|
||||
Check whether the credentials themselves are valid, now that we know
|
||||
if the key matches the user.
|
||||
|
||||
@param validKey: A boolean indicating whether or not the public key
|
||||
matches a key in the user's authorized_keys file.
|
||||
|
||||
@param credentials: The credentials offered by the user.
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: (as a failure) if the key does not match the
|
||||
user in C{credentials}. Also raised if the user provides an invalid
|
||||
signature.
|
||||
|
||||
@raise ValidPublicKey: (as a failure) if the key matches the user but
|
||||
the credentials do not include a signature. See
|
||||
L{error.ValidPublicKey} for more information.
|
||||
|
||||
@return: The user's username, if authentication was successful.
|
||||
"""
|
||||
if not validKey:
|
||||
return failure.Failure(UnauthorizedLogin("invalid key"))
|
||||
if not credentials.signature:
|
||||
return failure.Failure(error.ValidPublicKey())
|
||||
else:
|
||||
try:
|
||||
pubKey = keys.Key.fromString(credentials.blob)
|
||||
if pubKey.verify(credentials.signature, credentials.sigData):
|
||||
return credentials.username
|
||||
except: # any error should be treated as a failed login
|
||||
log.err()
|
||||
return failure.Failure(UnauthorizedLogin('error while verifying key'))
|
||||
return failure.Failure(UnauthorizedLogin("unable to verify key"))
|
||||
|
||||
|
||||
def getAuthorizedKeysFiles(self, credentials):
|
||||
"""
|
||||
Return a list of L{FilePath} instances for I{authorized_keys} files
|
||||
which might contain information about authorized keys for the given
|
||||
credentials.
|
||||
|
||||
On OpenSSH servers, the default location of the file containing the
|
||||
list of authorized public keys is
|
||||
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
|
||||
|
||||
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
|
||||
U{deprecated by OpenSSH since
|
||||
2001<http://marc.info/?m=100508718416162>}.
|
||||
|
||||
@return: A list of L{FilePath} instances to files with the authorized keys.
|
||||
"""
|
||||
pwent = self._userdb.getpwnam(credentials.username)
|
||||
root = FilePath(pwent.pw_dir).child('.ssh')
|
||||
files = ['authorized_keys', 'authorized_keys2']
|
||||
return [root.child(f) for f in files]
|
||||
|
||||
|
||||
def checkKey(self, credentials):
|
||||
"""
|
||||
Retrieve files containing authorized keys and check against user
|
||||
credentials.
|
||||
"""
|
||||
ouid, ogid = self._userdb.getpwnam(credentials.username)[2:4]
|
||||
for filepath in self.getAuthorizedKeysFiles(credentials):
|
||||
if not filepath.exists():
|
||||
continue
|
||||
try:
|
||||
lines = filepath.open()
|
||||
except IOError, e:
|
||||
if e.errno == errno.EACCES:
|
||||
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
|
||||
else:
|
||||
raise
|
||||
for l in lines:
|
||||
l2 = l.split()
|
||||
if len(l2) < 2:
|
||||
continue
|
||||
try:
|
||||
if base64.decodestring(l2[1]) == credentials.blob:
|
||||
return True
|
||||
except binascii.Error:
|
||||
continue
|
||||
return False
|
||||
|
||||
def _ebRequestAvatarId(self, f):
|
||||
if not f.check(UnauthorizedLogin):
|
||||
log.msg(f)
|
||||
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
|
||||
return f
|
||||
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHProtocolChecker:
|
||||
"""
|
||||
SSHProtocolChecker is a checker that requires multiple authentications
|
||||
to succeed. To add a checker, call my registerChecker method with
|
||||
the checker and the interface.
|
||||
|
||||
After each successful authenticate, I call my areDone method with the
|
||||
avatar id. To get a list of the successful credentials for an avatar id,
|
||||
use C{SSHProcotolChecker.successfulCredentials[avatarId]}. If L{areDone}
|
||||
returns True, the authentication has succeeded.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.checkers = {}
|
||||
self.successfulCredentials = {}
|
||||
|
||||
def get_credentialInterfaces(self):
|
||||
return self.checkers.keys()
|
||||
|
||||
credentialInterfaces = property(get_credentialInterfaces)
|
||||
|
||||
def registerChecker(self, checker, *credentialInterfaces):
|
||||
if not credentialInterfaces:
|
||||
credentialInterfaces = checker.credentialInterfaces
|
||||
for credentialInterface in credentialInterfaces:
|
||||
self.checkers[credentialInterface] = checker
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
"""
|
||||
Part of the L{ICredentialsChecker} interface. Called by a portal with
|
||||
some credentials to check if they'll authenticate a user. We check the
|
||||
interfaces that the credentials provide against our list of acceptable
|
||||
checkers. If one of them matches, we ask that checker to verify the
|
||||
credentials. If they're valid, we call our L{_cbGoodAuthentication}
|
||||
method to continue.
|
||||
|
||||
@param credentials: the credentials the L{Portal} wants us to verify
|
||||
"""
|
||||
ifac = providedBy(credentials)
|
||||
for i in ifac:
|
||||
c = self.checkers.get(i)
|
||||
if c is not None:
|
||||
d = defer.maybeDeferred(c.requestAvatarId, credentials)
|
||||
return d.addCallback(self._cbGoodAuthentication,
|
||||
credentials)
|
||||
return defer.fail(UnhandledCredentials("No checker for %s" % \
|
||||
', '.join(map(reflect.qual, ifac))))
|
||||
|
||||
def _cbGoodAuthentication(self, avatarId, credentials):
|
||||
"""
|
||||
Called if a checker has verified the credentials. We call our
|
||||
L{areDone} method to see if the whole of the successful authentications
|
||||
are enough. If they are, we return the avatar ID returned by the first
|
||||
checker.
|
||||
"""
|
||||
if avatarId not in self.successfulCredentials:
|
||||
self.successfulCredentials[avatarId] = []
|
||||
self.successfulCredentials[avatarId].append(credentials)
|
||||
if self.areDone(avatarId):
|
||||
del self.successfulCredentials[avatarId]
|
||||
return avatarId
|
||||
else:
|
||||
raise error.NotEnoughAuthentication()
|
||||
|
||||
def areDone(self, avatarId):
|
||||
"""
|
||||
Override to determine if the authentication is finished for a given
|
||||
avatarId.
|
||||
|
||||
@param avatarId: the avatar returned by the first checker. For
|
||||
this checker to function correctly, all the checkers must
|
||||
return the same avatar ID.
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version("Twisted", 15, 0, 0),
|
||||
("Please use twisted.conch.checkers.SSHPublicKeyChecker, "
|
||||
"initialized with an instance of "
|
||||
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."),
|
||||
__name__, "SSHPublicKeyDatabase")
|
||||
|
||||
|
||||
|
||||
class IAuthorizedKeysDB(Interface):
|
||||
"""
|
||||
An object that provides valid authorized ssh keys mapped to usernames.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
def getAuthorizedKeys(avatarId):
|
||||
"""
|
||||
Gets an iterable of authorized keys that are valid for the given
|
||||
C{avatarId}.
|
||||
|
||||
@param avatarId: the ID of the avatar
|
||||
@type avatarId: valid return value of
|
||||
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
|
||||
|
||||
@return: an iterable of L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def readAuthorizedKeyFile(fileobj, parseKey=keys.Key.fromString):
|
||||
"""
|
||||
Reads keys from an authorized keys file. Any non-comment line that cannot
|
||||
be parsed as a key will be ignored, although that particular line will
|
||||
be logged.
|
||||
|
||||
@param fileobj: something from which to read lines which can be parsed
|
||||
as keys
|
||||
@type fileobj: L{file}-like object
|
||||
|
||||
@param parseKey: a callable that takes a string and returns a
|
||||
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
|
||||
default is L{twisted.conch.ssh.keys.Key.fromString}.
|
||||
@type parseKey: L{callable}
|
||||
|
||||
@return: an iterable of L{twisted.conch.ssh.keys.Key}
|
||||
@rtype: iterable
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
for line in fileobj:
|
||||
line = line.strip()
|
||||
if line and not line.startswith('#'): # for comments
|
||||
try:
|
||||
yield parseKey(line)
|
||||
except keys.BadKeyError as e:
|
||||
log.msg('Unable to parse line "{0}" as a key: {1!s}'
|
||||
.format(line, e))
|
||||
|
||||
|
||||
|
||||
def _keysFromFilepaths(filepaths, parseKey):
|
||||
"""
|
||||
Helper function that turns an iterable of filepaths into a generator of
|
||||
keys. If any file cannot be read, a message is logged but it is
|
||||
otherwise ignored.
|
||||
|
||||
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
|
||||
@type filepaths: iterable
|
||||
|
||||
@param parseKey: a callable that takes a string and returns a
|
||||
L{twisted.conch.ssh.keys.Key}
|
||||
@type parseKey: L{callable}
|
||||
|
||||
@return: generator of L{twisted.conch.ssh.keys.Key}
|
||||
@rtype: generator
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
for fp in filepaths:
|
||||
if fp.exists():
|
||||
try:
|
||||
with fp.open() as f:
|
||||
for key in readAuthorizedKeyFile(f, parseKey):
|
||||
yield key
|
||||
except (IOError, OSError) as e:
|
||||
log.msg("Unable to read {0}: {1!s}".format(fp.path, e))
|
||||
|
||||
|
||||
|
||||
@implementer(IAuthorizedKeysDB)
|
||||
class InMemorySSHKeyDB(object):
|
||||
"""
|
||||
Object that provides SSH public keys based on a dictionary of usernames
|
||||
mapped to L{twisted.conch.ssh.keys.Key}s.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
def __init__(self, mapping):
|
||||
"""
|
||||
Initializes a new L{InMemorySSHKeyDB}.
|
||||
|
||||
@param mapping: mapping of usernames to iterables of
|
||||
L{twisted.conch.ssh.keys.Key}s
|
||||
@type mapping: C{dict}
|
||||
|
||||
"""
|
||||
self._mapping = mapping
|
||||
|
||||
|
||||
def getAuthorizedKeys(self, username):
|
||||
return self._mapping.get(username, [])
|
||||
|
||||
|
||||
|
||||
@implementer(IAuthorizedKeysDB)
|
||||
class UNIXAuthorizedKeysFiles(object):
|
||||
"""
|
||||
Object that provides SSH public keys based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
|
||||
If any of the files cannot be read, a message is logged but that file is
|
||||
otherwise ignored.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
def __init__(self, userdb=None, parseKey=keys.Key.fromString):
|
||||
"""
|
||||
Initializes a new L{UNIXAuthorizedKeysFiles}.
|
||||
|
||||
@param userdb: access to the Unix user account and password database
|
||||
(default is the Python module L{pwd})
|
||||
@type userdb: L{pwd}-like object
|
||||
|
||||
@param parseKey: a callable that takes a string and returns a
|
||||
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
|
||||
default is L{twisted.conch.ssh.keys.Key.fromString}.
|
||||
@type parseKey: L{callable}
|
||||
"""
|
||||
self._userdb = userdb
|
||||
self._parseKey = parseKey
|
||||
if userdb is None:
|
||||
self._userdb = pwd
|
||||
|
||||
|
||||
def getAuthorizedKeys(self, username):
|
||||
try:
|
||||
passwd = self._userdb.getpwnam(username)
|
||||
except KeyError:
|
||||
return ()
|
||||
|
||||
root = FilePath(passwd.pw_dir).child('.ssh')
|
||||
files = ['authorized_keys', 'authorized_keys2']
|
||||
return _keysFromFilepaths((root.child(f) for f in files),
|
||||
self._parseKey)
|
||||
|
||||
|
||||
|
||||
@implementer(ICredentialsChecker)
|
||||
class SSHPublicKeyChecker(object):
|
||||
"""
|
||||
Checker that authenticates SSH public keys, based on public keys listed in
|
||||
authorized_keys and authorized_keys2 files in user .ssh/ directories.
|
||||
|
||||
Initializing this checker with a L{UNIXAuthorizedKeysFiles} should be
|
||||
used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
|
||||
|
||||
@since: 15.0
|
||||
"""
|
||||
credentialInterfaces = (ISSHPrivateKey,)
|
||||
|
||||
def __init__(self, keydb):
|
||||
"""
|
||||
Initializes a L{SSHPublicKeyChecker}.
|
||||
|
||||
@param keydb: a provider of L{IAuthorizedKeysDB}
|
||||
@type keydb: L{IAuthorizedKeysDB} provider
|
||||
"""
|
||||
self._keydb = keydb
|
||||
|
||||
|
||||
def requestAvatarId(self, credentials):
|
||||
d = defer.maybeDeferred(self._sanityCheckKey, credentials)
|
||||
d.addCallback(self._checkKey, credentials)
|
||||
d.addCallback(self._verifyKey, credentials)
|
||||
return d
|
||||
|
||||
|
||||
def _sanityCheckKey(self, credentials):
|
||||
"""
|
||||
Checks whether the provided credentials are a valid SSH key with a
|
||||
signature (does not actually verify the signature).
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise ValidPublicKey: the credentials do not include a signature. See
|
||||
L{error.ValidPublicKey} for more information.
|
||||
|
||||
@raise BadKeyError: The key included with the credentials is not
|
||||
recognized as a key.
|
||||
|
||||
@return: the key in the credentials
|
||||
@rtype: L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
if not credentials.signature:
|
||||
raise error.ValidPublicKey()
|
||||
|
||||
return keys.Key.fromString(credentials.blob)
|
||||
|
||||
|
||||
def _checkKey(self, pubKey, credentials):
|
||||
"""
|
||||
Checks the public key against all authorized keys (if any) for the
|
||||
user.
|
||||
|
||||
@param pubKey: the key in the credentials (just to prevent it from
|
||||
having to be calculated again)
|
||||
@type pubKey:
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: If the key is not authorized, or if there
|
||||
was any error obtaining a list of authorized keys for the user.
|
||||
|
||||
@return: C{pubKey} if the key is authorized
|
||||
@rtype: L{twisted.conch.ssh.keys.Key}
|
||||
"""
|
||||
if any(key == pubKey for key in
|
||||
self._keydb.getAuthorizedKeys(credentials.username)):
|
||||
return pubKey
|
||||
|
||||
raise UnauthorizedLogin("Key not authorized")
|
||||
|
||||
|
||||
def _verifyKey(self, pubKey, credentials):
|
||||
"""
|
||||
Checks whether the credentials themselves are valid, now that we know
|
||||
if the key matches the user.
|
||||
|
||||
@param pubKey: the key in the credentials (just to prevent it from
|
||||
having to be calculated again)
|
||||
@type pubKey: L{twisted.conch.ssh.keys.Key}
|
||||
|
||||
@param credentials: the credentials offered by the user
|
||||
@type credentials: L{ISSHPrivateKey} provider
|
||||
|
||||
@raise UnauthorizedLogin: If the key signature is invalid or there
|
||||
was any error verifying the signature.
|
||||
|
||||
@return: The user's username, if authentication was successful
|
||||
@rtype: C{str}
|
||||
"""
|
||||
try:
|
||||
if pubKey.verify(credentials.signature, credentials.sigData):
|
||||
return credentials.username
|
||||
except: # any error should be treated as a failed login
|
||||
log.err()
|
||||
raise UnauthorizedLogin('Error while verifying key')
|
||||
|
||||
raise UnauthorizedLogin("Key signature invalid.")
|
@ -0,0 +1,9 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
"""
|
||||
Client support code for Conch.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
@ -0,0 +1,73 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_default -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Accesses the key agent for user authentication.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from twisted.conch.ssh import agent, channel, keys
|
||||
from twisted.internet import protocol, reactor
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
|
||||
class SSHAgentClient(agent.SSHAgentClient):
|
||||
|
||||
def __init__(self):
|
||||
agent.SSHAgentClient.__init__(self)
|
||||
self.blobs = []
|
||||
|
||||
|
||||
def getPublicKeys(self):
|
||||
return self.requestIdentities().addCallback(self._cbPublicKeys)
|
||||
|
||||
|
||||
def _cbPublicKeys(self, blobcomm):
|
||||
log.msg('got %i public keys' % len(blobcomm))
|
||||
self.blobs = [x[0] for x in blobcomm]
|
||||
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Return a L{Key} from the first blob in C{self.blobs}, if any, or
|
||||
return C{None}.
|
||||
"""
|
||||
if self.blobs:
|
||||
return keys.Key.fromString(self.blobs.pop(0))
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class SSHAgentForwardingChannel(channel.SSHChannel):
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
|
||||
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
|
||||
d.addCallback(self._cbGotLocal)
|
||||
d.addErrback(lambda x:self.loseConnection())
|
||||
self.buf = ''
|
||||
|
||||
|
||||
def _cbGotLocal(self, local):
|
||||
self.local = local
|
||||
self.dataReceived = self.local.transport.write
|
||||
self.local.dataReceived = self.write
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
|
||||
|
||||
def closed(self):
|
||||
if self.local:
|
||||
self.local.loseConnection()
|
||||
self.local = None
|
||||
|
||||
|
||||
class SSHAgentForwardingLocal(protocol.Protocol):
|
||||
pass
|
@ -0,0 +1,21 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
import direct
|
||||
|
||||
connectTypes = {"direct" : direct.connect}
|
||||
|
||||
def connect(host, port, options, verifyHostKey, userAuthObject):
|
||||
useConnects = ['direct']
|
||||
return _ebConnect(None, useConnects, host, port, options, verifyHostKey,
|
||||
userAuthObject)
|
||||
|
||||
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
|
||||
if not useConnects:
|
||||
return f
|
||||
connectType = useConnects.pop(0)
|
||||
f = connectTypes[connectType]
|
||||
d = f(host, port, options, vhk, uao)
|
||||
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
|
||||
return d
|
@ -0,0 +1,260 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Various classes and functions for implementing user-interaction in the
|
||||
command-line conch client.
|
||||
|
||||
You probably shouldn't use anything in this module directly, since it assumes
|
||||
you are sitting at an interactive terminal. For example, to programmatically
|
||||
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
|
||||
"""
|
||||
|
||||
from twisted.python import log
|
||||
from twisted.python.filepath import FilePath
|
||||
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.ssh import common, keys, userauth
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
|
||||
from twisted.conch.client.knownhosts import KnownHostsFile, ConsoleUI
|
||||
|
||||
from twisted.conch.client import agent
|
||||
|
||||
import os, sys, base64, getpass
|
||||
|
||||
# The default location of the known hosts file (probably should be parsed out
|
||||
# of an ssh config file someday).
|
||||
_KNOWN_HOSTS = "~/.ssh/known_hosts"
|
||||
|
||||
|
||||
# This name is bound so that the unit tests can use 'patch' to override it.
|
||||
_open = open
|
||||
|
||||
def verifyHostKey(transport, host, pubKey, fingerprint):
|
||||
"""
|
||||
Verify a host's key.
|
||||
|
||||
This function is a gross vestige of some bad factoring in the client
|
||||
internals. The actual implementation, and a better signature of this logic
|
||||
is in L{KnownHostsFile.verifyHostKey}. This function is not deprecated yet
|
||||
because the callers have not yet been rehabilitated, but they should
|
||||
eventually be changed to call that method instead.
|
||||
|
||||
However, this function does perform two functions not implemented by
|
||||
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
|
||||
known_hosts file based on the options (which should really be the options
|
||||
object's job), and it provides an opener to L{ConsoleUI} which opens
|
||||
'/dev/tty' so that the user will be prompted on the tty of the process even
|
||||
if the input and output of the process has been redirected. This latter
|
||||
part is, somewhat obviously, not portable, but I don't know of a portable
|
||||
equivalent that could be used.
|
||||
|
||||
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
|
||||
always the dotted-quad IP address of the host being connected to.
|
||||
@type host: L{str}
|
||||
|
||||
@param transport: the client transport which is attempting to connect to
|
||||
the given host.
|
||||
@type transport: L{SSHClientTransport}
|
||||
|
||||
@param fingerprint: the fingerprint of the given public key, in
|
||||
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
|
||||
from the key itself.
|
||||
@type fingerprint: L{str}
|
||||
|
||||
@param pubKey: The public key of the server being connected to.
|
||||
@type pubKey: L{str}
|
||||
|
||||
@return: a L{Deferred} which fires with C{1} if the key was successfully
|
||||
verified, or fails if the key could not be successfully verified. Failure
|
||||
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
|
||||
L{KeyboardInterrupt}.
|
||||
"""
|
||||
actualHost = transport.factory.options['host']
|
||||
actualKey = keys.Key.fromString(pubKey)
|
||||
kh = KnownHostsFile.fromPath(FilePath(
|
||||
transport.factory.options['known-hosts']
|
||||
or os.path.expanduser(_KNOWN_HOSTS)
|
||||
))
|
||||
ui = ConsoleUI(lambda : _open("/dev/tty", "r+b"))
|
||||
return kh.verifyHostKey(ui, actualHost, host, actualKey)
|
||||
|
||||
|
||||
def isInKnownHosts(host, pubKey, options):
|
||||
"""checks to see if host is in the known_hosts file for the user.
|
||||
returns 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
|
||||
"""
|
||||
keyType = common.getNS(pubKey)[0]
|
||||
retVal = 0
|
||||
|
||||
if not options['known-hosts'] and not os.path.exists(os.path.expanduser('~/.ssh/')):
|
||||
print 'Creating ~/.ssh directory...'
|
||||
os.mkdir(os.path.expanduser('~/.ssh'))
|
||||
kh_file = options['known-hosts'] or _KNOWN_HOSTS
|
||||
try:
|
||||
known_hosts = open(os.path.expanduser(kh_file))
|
||||
except IOError:
|
||||
return 0
|
||||
for line in known_hosts.xreadlines():
|
||||
split = line.split()
|
||||
if len(split) < 3:
|
||||
continue
|
||||
hosts, hostKeyType, encodedKey = split[:3]
|
||||
if host not in hosts.split(','): # incorrect host
|
||||
continue
|
||||
if hostKeyType != keyType: # incorrect type of key
|
||||
continue
|
||||
try:
|
||||
decodedKey = base64.decodestring(encodedKey)
|
||||
except:
|
||||
continue
|
||||
if decodedKey == pubKey:
|
||||
return 1
|
||||
else:
|
||||
retVal = 2
|
||||
return retVal
|
||||
|
||||
|
||||
|
||||
class SSHUserAuthClient(userauth.SSHUserAuthClient):
|
||||
|
||||
def __init__(self, user, options, *args):
|
||||
userauth.SSHUserAuthClient.__init__(self, user, *args)
|
||||
self.keyAgent = None
|
||||
self.options = options
|
||||
self.usedFiles = []
|
||||
if not options.identitys:
|
||||
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
|
||||
|
||||
def serviceStarted(self):
|
||||
if 'SSH_AUTH_SOCK' in os.environ and not self.options['noagent']:
|
||||
log.msg('using agent')
|
||||
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
|
||||
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
|
||||
d.addCallback(self._setAgent)
|
||||
d.addErrback(self._ebSetAgent)
|
||||
else:
|
||||
userauth.SSHUserAuthClient.serviceStarted(self)
|
||||
|
||||
def serviceStopped(self):
|
||||
if self.keyAgent:
|
||||
self.keyAgent.transport.loseConnection()
|
||||
self.keyAgent = None
|
||||
|
||||
def _setAgent(self, a):
|
||||
self.keyAgent = a
|
||||
d = self.keyAgent.getPublicKeys()
|
||||
d.addBoth(self._ebSetAgent)
|
||||
return d
|
||||
|
||||
def _ebSetAgent(self, f):
|
||||
userauth.SSHUserAuthClient.serviceStarted(self)
|
||||
|
||||
def _getPassword(self, prompt):
|
||||
try:
|
||||
oldout, oldin = sys.stdout, sys.stdin
|
||||
sys.stdin = sys.stdout = open('/dev/tty','r+')
|
||||
p=getpass.getpass(prompt)
|
||||
sys.stdout,sys.stdin=oldout,oldin
|
||||
return p
|
||||
except (KeyboardInterrupt, IOError):
|
||||
print
|
||||
raise ConchError('PEBKAC')
|
||||
|
||||
def getPassword(self, prompt = None):
|
||||
if not prompt:
|
||||
prompt = "%s@%s's password: " % (self.user, self.transport.transport.getPeer().host)
|
||||
try:
|
||||
p = self._getPassword(prompt)
|
||||
return defer.succeed(p)
|
||||
except ConchError:
|
||||
return defer.fail()
|
||||
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Get a public key from the key agent if possible, otherwise look in
|
||||
the next configured identity file for one.
|
||||
"""
|
||||
if self.keyAgent:
|
||||
key = self.keyAgent.getPublicKey()
|
||||
if key is not None:
|
||||
return key
|
||||
files = [x for x in self.options.identitys if x not in self.usedFiles]
|
||||
log.msg(str(self.options.identitys))
|
||||
log.msg(str(files))
|
||||
if not files:
|
||||
return None
|
||||
file = files[0]
|
||||
log.msg(file)
|
||||
self.usedFiles.append(file)
|
||||
file = os.path.expanduser(file)
|
||||
file += '.pub'
|
||||
if not os.path.exists(file):
|
||||
return self.getPublicKey() # try again
|
||||
try:
|
||||
return keys.Key.fromFile(file)
|
||||
except keys.BadKeyError:
|
||||
return self.getPublicKey() # try again
|
||||
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Extend the base signing behavior by using an SSH agent to sign the
|
||||
data, if one is available.
|
||||
|
||||
@type publicKey: L{Key}
|
||||
@type signData: C{str}
|
||||
"""
|
||||
if not self.usedFiles: # agent key
|
||||
return self.keyAgent.signData(publicKey.blob(), signData)
|
||||
else:
|
||||
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
|
||||
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Try to load the private key from the last used file identified by
|
||||
C{getPublicKey}, potentially asking for the passphrase if the key is
|
||||
encrypted.
|
||||
"""
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
if not os.path.exists(file):
|
||||
return None
|
||||
try:
|
||||
return defer.succeed(keys.Key.fromFile(file))
|
||||
except keys.EncryptedKeyError:
|
||||
for i in range(3):
|
||||
prompt = "Enter passphrase for key '%s': " % \
|
||||
self.usedFiles[-1]
|
||||
try:
|
||||
p = self._getPassword(prompt)
|
||||
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
|
||||
except (keys.BadKeyError, ConchError):
|
||||
pass
|
||||
return defer.fail(ConchError('bad password'))
|
||||
raise
|
||||
except KeyboardInterrupt:
|
||||
print
|
||||
reactor.stop()
|
||||
|
||||
|
||||
def getGenericAnswers(self, name, instruction, prompts):
|
||||
responses = []
|
||||
try:
|
||||
oldout, oldin = sys.stdout, sys.stdin
|
||||
sys.stdin = sys.stdout = open('/dev/tty','r+')
|
||||
if name:
|
||||
print name
|
||||
if instruction:
|
||||
print instruction
|
||||
for prompt, echo in prompts:
|
||||
if echo:
|
||||
responses.append(raw_input(prompt))
|
||||
else:
|
||||
responses.append(getpass.getpass(prompt))
|
||||
finally:
|
||||
sys.stdout,sys.stdin=oldout,oldin
|
||||
return defer.succeed(responses)
|
@ -0,0 +1,107 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import transport
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
|
||||
class SSHClientFactory(protocol.ClientFactory):
|
||||
|
||||
def __init__(self, d, options, verifyHostKey, userAuthObject):
|
||||
self.d = d
|
||||
self.options = options
|
||||
self.verifyHostKey = verifyHostKey
|
||||
self.userAuthObject = userAuthObject
|
||||
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
if self.options['reconnect']:
|
||||
connector.connect()
|
||||
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
if self.d is None:
|
||||
return
|
||||
d, self.d = self.d, None
|
||||
d.errback(reason)
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
trans = SSHClientTransport(self)
|
||||
if self.options['ciphers']:
|
||||
trans.supportedCiphers = self.options['ciphers']
|
||||
if self.options['macs']:
|
||||
trans.supportedMACs = self.options['macs']
|
||||
if self.options['compress']:
|
||||
trans.supportedCompressions[0:1] = ['zlib']
|
||||
if self.options['host-key-algorithms']:
|
||||
trans.supportedPublicKeys = self.options['host-key-algorithms']
|
||||
return trans
|
||||
|
||||
|
||||
|
||||
class SSHClientTransport(transport.SSHClientTransport):
|
||||
|
||||
def __init__(self, factory):
|
||||
self.factory = factory
|
||||
self.unixServer = None
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.unixServer:
|
||||
d = self.unixServer.stopListening()
|
||||
self.unixServer = None
|
||||
else:
|
||||
d = defer.succeed(None)
|
||||
d.addCallback(lambda x:
|
||||
transport.SSHClientTransport.connectionLost(self, reason))
|
||||
|
||||
|
||||
def receiveError(self, code, desc):
|
||||
if self.factory.d is None:
|
||||
return
|
||||
d, self.factory.d = self.factory.d, None
|
||||
d.errback(error.ConchError(desc, code))
|
||||
|
||||
|
||||
def sendDisconnect(self, code, reason):
|
||||
if self.factory.d is None:
|
||||
return
|
||||
d, self.factory.d = self.factory.d, None
|
||||
transport.SSHClientTransport.sendDisconnect(self, code, reason)
|
||||
d.errback(error.ConchError(reason, code))
|
||||
|
||||
|
||||
def receiveDebug(self, alwaysDisplay, message, lang):
|
||||
log.msg('Received Debug Message: %s' % message)
|
||||
if alwaysDisplay: # XXX what should happen here?
|
||||
print message
|
||||
|
||||
|
||||
def verifyHostKey(self, pubKey, fingerprint):
|
||||
return self.factory.verifyHostKey(self, self.transport.getPeer().host, pubKey,
|
||||
fingerprint)
|
||||
|
||||
|
||||
def setService(self, service):
|
||||
log.msg('setting client server to %s' % service)
|
||||
transport.SSHClientTransport.setService(self, service)
|
||||
if service.name != 'ssh-userauth' and self.factory.d is not None:
|
||||
d, self.factory.d = self.factory.d, None
|
||||
d.callback(None)
|
||||
|
||||
|
||||
def connectionSecure(self):
|
||||
self.requestService(self.factory.userAuthObject)
|
||||
|
||||
|
||||
|
||||
def connect(host, port, options, verifyHostKey, userAuthObject):
|
||||
d = defer.Deferred()
|
||||
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
|
||||
reactor.connectTCP(host, port, factory)
|
||||
return d
|
@ -0,0 +1,621 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
An implementation of the OpenSSH known_hosts database.
|
||||
|
||||
@since: 8.2
|
||||
"""
|
||||
|
||||
import hmac
|
||||
from binascii import Error as DecodeError, b2a_base64
|
||||
from hashlib import sha1
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.python.randbytes import secureRandom
|
||||
from twisted.internet import defer
|
||||
from twisted.python import log
|
||||
from twisted.python.util import FancyEqMixin
|
||||
from twisted.conch.interfaces import IKnownHostEntry
|
||||
from twisted.conch.error import HostKeyChanged, UserRejectedKey, InvalidEntry
|
||||
from twisted.conch.ssh.keys import Key, BadKeyError
|
||||
|
||||
|
||||
def _b64encode(s):
|
||||
"""
|
||||
Encode a binary string as base64 with no trailing newline.
|
||||
|
||||
@param s: The string to encode.
|
||||
@type s: L{bytes}
|
||||
|
||||
@return: The base64-encoded string.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return b2a_base64(s).strip()
|
||||
|
||||
|
||||
|
||||
def _extractCommon(string):
|
||||
"""
|
||||
Extract common elements of base64 keys from an entry in a hosts file.
|
||||
|
||||
@param string: A known hosts file entry (a single line).
|
||||
@type string: L{bytes}
|
||||
|
||||
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
|
||||
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
|
||||
simply the beginning of the line up to the first occurrence of
|
||||
whitespace.
|
||||
@rtype: L{tuple}
|
||||
"""
|
||||
elements = string.split(None, 2)
|
||||
if len(elements) != 3:
|
||||
raise InvalidEntry()
|
||||
hostnames, keyType, keyAndComment = elements
|
||||
splitkey = keyAndComment.split(None, 1)
|
||||
if len(splitkey) == 2:
|
||||
keyString, comment = splitkey
|
||||
comment = comment.rstrip("\n")
|
||||
else:
|
||||
keyString = splitkey[0]
|
||||
comment = None
|
||||
key = Key.fromString(keyString.decode('base64'))
|
||||
return hostnames, keyType, key, comment
|
||||
|
||||
|
||||
|
||||
class _BaseEntry(object):
|
||||
"""
|
||||
Abstract base of both hashed and non-hashed entry objects, since they
|
||||
represent keys and key types the same way.
|
||||
|
||||
@ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
|
||||
@type keyType: L{str}
|
||||
|
||||
@ivar publicKey: The server public key indicated by this line.
|
||||
@type publicKey: L{twisted.conch.ssh.keys.Key}
|
||||
|
||||
@ivar comment: Trailing garbage after the key line.
|
||||
@type comment: L{str}
|
||||
"""
|
||||
|
||||
def __init__(self, keyType, publicKey, comment):
|
||||
self.keyType = keyType
|
||||
self.publicKey = publicKey
|
||||
self.comment = comment
|
||||
|
||||
|
||||
def matchesKey(self, keyObject):
|
||||
"""
|
||||
Check to see if this entry matches a given key object.
|
||||
|
||||
@param keyObject: A public key object to check.
|
||||
@type keyObject: L{Key}
|
||||
|
||||
@return: C{True} if this entry's key matches C{keyObject}, C{False}
|
||||
otherwise.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return self.publicKey == keyObject
|
||||
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class PlainEntry(_BaseEntry):
|
||||
"""
|
||||
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
|
||||
file.
|
||||
|
||||
@ivar _hostnames: the list of all host-names associated with this entry.
|
||||
@type _hostnames: L{list} of L{str}
|
||||
"""
|
||||
|
||||
def __init__(self, hostnames, keyType, publicKey, comment):
|
||||
self._hostnames = hostnames
|
||||
super(PlainEntry, self).__init__(keyType, publicKey, comment)
|
||||
|
||||
|
||||
def fromString(cls, string):
|
||||
"""
|
||||
Parse a plain-text entry in a known_hosts file, and return a
|
||||
corresponding L{PlainEntry}.
|
||||
|
||||
@param string: a space-separated string formatted like "hostname
|
||||
key-type base64-key-data comment".
|
||||
|
||||
@type string: L{str}
|
||||
|
||||
@raise DecodeError: if the key is not valid encoded as valid base64.
|
||||
|
||||
@raise InvalidEntry: if the entry does not have the right number of
|
||||
elements and is therefore invalid.
|
||||
|
||||
@raise BadKeyError: if the key, once decoded from base64, is not
|
||||
actually an SSH key.
|
||||
|
||||
@return: an IKnownHostEntry representing the hostname and key in the
|
||||
input line.
|
||||
|
||||
@rtype: L{PlainEntry}
|
||||
"""
|
||||
hostnames, keyType, key, comment = _extractCommon(string)
|
||||
self = cls(hostnames.split(","), keyType, key, comment)
|
||||
return self
|
||||
|
||||
fromString = classmethod(fromString)
|
||||
|
||||
|
||||
def matchesHost(self, hostname):
|
||||
"""
|
||||
Check to see if this entry matches a given hostname.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check against this
|
||||
entry.
|
||||
@type hostname: L{str}
|
||||
|
||||
@return: C{True} if this entry is for the given hostname or IP address,
|
||||
C{False} otherwise.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return hostname in self._hostnames
|
||||
|
||||
|
||||
def toString(self):
|
||||
"""
|
||||
Implement L{IKnownHostEntry.toString} by recording the comma-separated
|
||||
hostnames, key type, and base-64 encoded key.
|
||||
|
||||
@return: The string representation of this entry, with unhashed hostname
|
||||
information.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
fields = [','.join(self._hostnames),
|
||||
self.keyType,
|
||||
_b64encode(self.publicKey.blob())]
|
||||
if self.comment is not None:
|
||||
fields.append(self.comment)
|
||||
return ' '.join(fields)
|
||||
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class UnparsedEntry(object):
|
||||
"""
|
||||
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
|
||||
parsed; therefore it matches no keys and no hosts.
|
||||
"""
|
||||
|
||||
def __init__(self, string):
|
||||
"""
|
||||
Create an unparsed entry from a line in a known_hosts file which cannot
|
||||
otherwise be parsed.
|
||||
"""
|
||||
self._string = string
|
||||
|
||||
|
||||
def matchesHost(self, hostname):
|
||||
"""
|
||||
Always returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def matchesKey(self, key):
|
||||
"""
|
||||
Always returns False.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
def toString(self):
|
||||
"""
|
||||
Returns the input line, without its newline if one was given.
|
||||
|
||||
@return: The string representation of this entry, almost exactly as was
|
||||
used to initialize this entry but without a trailing newline.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return self._string.rstrip("\n")
|
||||
|
||||
|
||||
|
||||
def _hmacedString(key, string):
|
||||
"""
|
||||
Return the SHA-1 HMAC hash of the given key and string.
|
||||
|
||||
@param key: The HMAC key.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param string: The string to be hashed.
|
||||
@type string: L{bytes}
|
||||
|
||||
@return: The keyed hash value.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
hash = hmac.HMAC(key, digestmod=sha1)
|
||||
hash.update(string)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
|
||||
@implementer(IKnownHostEntry)
|
||||
class HashedEntry(_BaseEntry, FancyEqMixin):
|
||||
"""
|
||||
A L{HashedEntry} is a representation of an entry in a known_hosts file
|
||||
where the hostname has been hashed and salted.
|
||||
|
||||
@ivar _hostSalt: the salt to combine with a hostname for hashing.
|
||||
|
||||
@ivar _hostHash: the hashed representation of the hostname.
|
||||
|
||||
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
|
||||
known_hosts file as opposed to a plaintext one.
|
||||
"""
|
||||
|
||||
MAGIC = '|1|'
|
||||
|
||||
compareAttributes = (
|
||||
"_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
|
||||
|
||||
def __init__(self, hostSalt, hostHash, keyType, publicKey, comment):
|
||||
self._hostSalt = hostSalt
|
||||
self._hostHash = hostHash
|
||||
super(HashedEntry, self).__init__(keyType, publicKey, comment)
|
||||
|
||||
|
||||
def fromString(cls, string):
|
||||
"""
|
||||
Load a hashed entry from a string representing a line in a known_hosts
|
||||
file.
|
||||
|
||||
@param string: A complete single line from a I{known_hosts} file,
|
||||
formatted as defined by OpenSSH.
|
||||
@type string: L{bytes}
|
||||
|
||||
@raise DecodeError: if the key, the hostname, or the is not valid
|
||||
encoded as valid base64
|
||||
|
||||
@raise InvalidEntry: if the entry does not have the right number of
|
||||
elements and is therefore invalid, or the host/hash portion contains
|
||||
more items than just the host and hash.
|
||||
|
||||
@raise BadKeyError: if the key, once decoded from base64, is not
|
||||
actually an SSH key.
|
||||
|
||||
@return: The newly created L{HashedEntry} instance, initialized with the
|
||||
information from C{string}.
|
||||
"""
|
||||
stuff, keyType, key, comment = _extractCommon(string)
|
||||
saltAndHash = stuff[len(cls.MAGIC):].split("|")
|
||||
if len(saltAndHash) != 2:
|
||||
raise InvalidEntry()
|
||||
hostSalt, hostHash = saltAndHash
|
||||
self = cls(hostSalt.decode("base64"), hostHash.decode("base64"),
|
||||
keyType, key, comment)
|
||||
return self
|
||||
|
||||
fromString = classmethod(fromString)
|
||||
|
||||
|
||||
def matchesHost(self, hostname):
|
||||
"""
|
||||
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
|
||||
input to the stored hash.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check against this
|
||||
entry.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@return: C{True} if this entry is for the given hostname or IP address,
|
||||
C{False} otherwise.
|
||||
@rtype: L{bool}
|
||||
"""
|
||||
return (_hmacedString(self._hostSalt, hostname) == self._hostHash)
|
||||
|
||||
|
||||
def toString(self):
|
||||
"""
|
||||
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
|
||||
hash, and key.
|
||||
|
||||
@return: The string representation of this entry, with the hostname part
|
||||
hashed.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
fields = [self.MAGIC + '|'.join([_b64encode(self._hostSalt),
|
||||
_b64encode(self._hostHash)]),
|
||||
self.keyType,
|
||||
_b64encode(self.publicKey.blob())]
|
||||
if self.comment is not None:
|
||||
fields.append(self.comment)
|
||||
return ' '.join(fields)
|
||||
|
||||
|
||||
|
||||
class KnownHostsFile(object):
|
||||
"""
|
||||
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
|
||||
|
||||
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
|
||||
to this instance in memory but not yet saved.
|
||||
|
||||
@ivar _clobber: A flag indicating whether the current contents of the save
|
||||
path will be disregarded and potentially overwritten or not. If
|
||||
C{True}, this will be done. If C{False}, entries in the save path will
|
||||
be read and new entries will be saved by appending rather than
|
||||
overwriting.
|
||||
@type _clobber: L{bool}
|
||||
|
||||
@ivar _savePath: See C{savePath} parameter of L{__init__}.
|
||||
"""
|
||||
|
||||
def __init__(self, savePath):
|
||||
"""
|
||||
Create a new, empty KnownHostsFile.
|
||||
|
||||
Unless you want to erase the current contents of C{savePath}, you want
|
||||
to use L{KnownHostsFile.fromPath} instead.
|
||||
|
||||
@param savePath: The L{FilePath} to which to save new entries.
|
||||
@type savePath: L{FilePath}
|
||||
"""
|
||||
self._added = []
|
||||
self._savePath = savePath
|
||||
self._clobber = True
|
||||
|
||||
|
||||
@property
|
||||
def savePath(self):
|
||||
"""
|
||||
@see: C{savePath} parameter of L{__init__}
|
||||
"""
|
||||
return self._savePath
|
||||
|
||||
|
||||
def iterentries(self):
|
||||
"""
|
||||
Iterate over the host entries in this file.
|
||||
|
||||
@return: An iterable the elements of which provide L{IKnownHostEntry}.
|
||||
There is an element for each entry in the file as well as an element
|
||||
for each added but not yet saved entry.
|
||||
@rtype: iterable of L{IKnownHostEntry} providers
|
||||
"""
|
||||
for entry in self._added:
|
||||
yield entry
|
||||
|
||||
if self._clobber:
|
||||
return
|
||||
|
||||
try:
|
||||
fp = self._savePath.open()
|
||||
except IOError:
|
||||
return
|
||||
|
||||
try:
|
||||
for line in fp:
|
||||
try:
|
||||
if line.startswith(HashedEntry.MAGIC):
|
||||
entry = HashedEntry.fromString(line)
|
||||
else:
|
||||
entry = PlainEntry.fromString(line)
|
||||
except (DecodeError, InvalidEntry, BadKeyError):
|
||||
entry = UnparsedEntry(line)
|
||||
yield entry
|
||||
finally:
|
||||
fp.close()
|
||||
|
||||
|
||||
def hasHostKey(self, hostname, key):
|
||||
"""
|
||||
Check for an entry with matching hostname and key.
|
||||
|
||||
@param hostname: A hostname or IP address literal to check for.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param key: The public key to check for.
|
||||
@type key: L{Key}
|
||||
|
||||
@return: C{True} if the given hostname and key are present in this file,
|
||||
C{False} if they are not.
|
||||
@rtype: L{bool}
|
||||
|
||||
@raise HostKeyChanged: if the host key found for the given hostname
|
||||
does not match the given key.
|
||||
"""
|
||||
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
|
||||
if entry.matchesHost(hostname):
|
||||
if entry.matchesKey(key):
|
||||
return True
|
||||
else:
|
||||
# Notice that lineidx is 0-based but HostKeyChanged.lineno
|
||||
# is 1-based.
|
||||
if lineidx < 0:
|
||||
line = None
|
||||
path = None
|
||||
else:
|
||||
line = lineidx + 1
|
||||
path = self._savePath
|
||||
raise HostKeyChanged(entry, path, line)
|
||||
return False
|
||||
|
||||
|
||||
def verifyHostKey(self, ui, hostname, ip, key):
|
||||
"""
|
||||
Verify the given host key for the given IP and host, asking for
|
||||
confirmation from, and notifying, the given UI about changes to this
|
||||
file.
|
||||
|
||||
@param ui: The user interface to request an IP address from.
|
||||
|
||||
@param hostname: The hostname that the user requested to connect to.
|
||||
|
||||
@param ip: The string representation of the IP address that is actually
|
||||
being connected to.
|
||||
|
||||
@param key: The public key of the server.
|
||||
|
||||
@return: a L{Deferred} that fires with True when the key has been
|
||||
verified, or fires with an errback when the key either cannot be
|
||||
verified or has changed.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
hhk = defer.maybeDeferred(self.hasHostKey, hostname, key)
|
||||
def gotHasKey(result):
|
||||
if result:
|
||||
if not self.hasHostKey(ip, key):
|
||||
ui.warn("Warning: Permanently added the %s host key for "
|
||||
"IP address '%s' to the list of known hosts." %
|
||||
(key.type(), ip))
|
||||
self.addHostKey(ip, key)
|
||||
self.save()
|
||||
return result
|
||||
else:
|
||||
def promptResponse(response):
|
||||
if response:
|
||||
self.addHostKey(hostname, key)
|
||||
self.addHostKey(ip, key)
|
||||
self.save()
|
||||
return response
|
||||
else:
|
||||
raise UserRejectedKey()
|
||||
proceed = ui.prompt(
|
||||
"The authenticity of host '%s (%s)' "
|
||||
"can't be established.\n"
|
||||
"RSA key fingerprint is %s.\n"
|
||||
"Are you sure you want to continue connecting (yes/no)? " %
|
||||
(hostname, ip, key.fingerprint()))
|
||||
return proceed.addCallback(promptResponse)
|
||||
return hhk.addCallback(gotHasKey)
|
||||
|
||||
|
||||
def addHostKey(self, hostname, key):
|
||||
"""
|
||||
Add a new L{HashedEntry} to the key database.
|
||||
|
||||
Note that you still need to call L{KnownHostsFile.save} if you wish
|
||||
these changes to be persisted.
|
||||
|
||||
@param hostname: A hostname or IP address literal to associate with the
|
||||
new entry.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param key: The public key to associate with the new entry.
|
||||
@type key: L{Key}
|
||||
|
||||
@return: The L{HashedEntry} that was added.
|
||||
@rtype: L{HashedEntry}
|
||||
"""
|
||||
salt = secureRandom(20)
|
||||
keyType = "ssh-" + key.type().lower()
|
||||
entry = HashedEntry(salt, _hmacedString(salt, hostname),
|
||||
keyType, key, None)
|
||||
self._added.append(entry)
|
||||
return entry
|
||||
|
||||
|
||||
def save(self):
|
||||
"""
|
||||
Save this L{KnownHostsFile} to the path it was loaded from.
|
||||
"""
|
||||
p = self._savePath.parent()
|
||||
if not p.isdir():
|
||||
p.makedirs()
|
||||
|
||||
if self._clobber:
|
||||
mode = "w"
|
||||
else:
|
||||
mode = "a"
|
||||
|
||||
with self._savePath.open(mode) as hostsFileObj:
|
||||
if self._added:
|
||||
hostsFileObj.write(
|
||||
"\n".join([entry.toString() for entry in self._added]) +
|
||||
"\n")
|
||||
self._added = []
|
||||
self._clobber = False
|
||||
|
||||
|
||||
def fromPath(cls, path):
|
||||
"""
|
||||
Create a new L{KnownHostsFile}, potentially reading existing known
|
||||
hosts information from the given file.
|
||||
|
||||
@param path: A path object to use for both reading contents from and
|
||||
later saving to. If no file exists at this path, it is not an
|
||||
error; a L{KnownHostsFile} with no entries is returned.
|
||||
@type path: L{FilePath}
|
||||
|
||||
@return: A L{KnownHostsFile} initialized with entries from C{path}.
|
||||
@rtype: L{KnownHostsFile}
|
||||
"""
|
||||
knownHosts = cls(path)
|
||||
knownHosts._clobber = False
|
||||
return knownHosts
|
||||
|
||||
fromPath = classmethod(fromPath)
|
||||
|
||||
|
||||
|
||||
class ConsoleUI(object):
|
||||
"""
|
||||
A UI object that can ask true/false questions and post notifications on the
|
||||
console, to be used during key verification.
|
||||
"""
|
||||
def __init__(self, opener):
|
||||
"""
|
||||
@param opener: A no-argument callable which should open a console
|
||||
binary-mode file-like object to be used for reading and writing.
|
||||
This initializes the C{opener} attribute.
|
||||
@type opener: callable taking no arguments and returning a read/write
|
||||
file-like object
|
||||
"""
|
||||
self.opener = opener
|
||||
|
||||
|
||||
def prompt(self, text):
|
||||
"""
|
||||
Write the given text as a prompt to the console output, then read a
|
||||
result from the console input.
|
||||
|
||||
@param text: Something to present to a user to solicit a yes or no
|
||||
response.
|
||||
@type text: L{bytes}
|
||||
|
||||
@return: a L{Deferred} which fires with L{True} when the user answers
|
||||
'yes' and L{False} when the user answers 'no'. It may errback if
|
||||
there were any I/O errors.
|
||||
"""
|
||||
d = defer.succeed(None)
|
||||
def body(ignored):
|
||||
f = self.opener()
|
||||
f.write(text)
|
||||
while True:
|
||||
answer = f.readline().strip().lower()
|
||||
if answer == 'yes':
|
||||
f.close()
|
||||
return True
|
||||
elif answer == 'no':
|
||||
f.close()
|
||||
return False
|
||||
else:
|
||||
f.write("Please type 'yes' or 'no': ")
|
||||
return d.addCallback(body)
|
||||
|
||||
|
||||
def warn(self, text):
|
||||
"""
|
||||
Notify the user (non-interactively) of the provided text, by writing it
|
||||
to the console.
|
||||
|
||||
@param text: Some information the user is to be made aware of.
|
||||
@type text: L{bytes}
|
||||
"""
|
||||
try:
|
||||
f = self.opener()
|
||||
f.write(text)
|
||||
f.close()
|
||||
except:
|
||||
log.err()
|
@ -0,0 +1,96 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
from twisted.conch.ssh.transport import SSHClientTransport, SSHCiphers
|
||||
from twisted.python import usage
|
||||
|
||||
import sys
|
||||
|
||||
class ConchOptions(usage.Options):
|
||||
|
||||
optParameters = [['user', 'l', None, 'Log in using this user name.'],
|
||||
['identity', 'i', None],
|
||||
['ciphers', 'c', None],
|
||||
['macs', 'm', None],
|
||||
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
|
||||
['option', 'o', None, 'Ignored OpenSSH options'],
|
||||
['host-key-algorithms', '', None],
|
||||
['known-hosts', '', None, 'File to check for host keys'],
|
||||
['user-authentications', '', None, 'Types of user authentications to use.'],
|
||||
['logfile', '', None, 'File to log to, or - for stdout'],
|
||||
]
|
||||
|
||||
optFlags = [['version', 'V', 'Display version number only.'],
|
||||
['compress', 'C', 'Enable compression.'],
|
||||
['log', 'v', 'Enable logging (defaults to stderr)'],
|
||||
['nox11', 'x', 'Disable X11 connection forwarding (default)'],
|
||||
['agent', 'A', 'Enable authentication agent forwarding'],
|
||||
['noagent', 'a', 'Disable authentication agent forwarding (default)'],
|
||||
['reconnect', 'r', 'Reconnect to the server if the connection is lost.'],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("agent", "noagent")],
|
||||
optActions={
|
||||
"user": usage.CompleteUsernames(),
|
||||
"ciphers": usage.CompleteMultiList(
|
||||
SSHCiphers.cipherMap.keys(),
|
||||
descr='ciphers to choose from'),
|
||||
"macs": usage.CompleteMultiList(
|
||||
SSHCiphers.macMap.keys(),
|
||||
descr='macs to choose from'),
|
||||
"host-key-algorithms": usage.CompleteMultiList(
|
||||
SSHClientTransport.supportedPublicKeys,
|
||||
descr='host key algorithms to choose from'),
|
||||
#"user-authentications": usage.CompleteMultiList(?
|
||||
# descr='user authentication types' ),
|
||||
},
|
||||
extraActions=[usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr='argument',
|
||||
repeat=True)]
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kw):
|
||||
usage.Options.__init__(self, *args, **kw)
|
||||
self.identitys = []
|
||||
self.conns = None
|
||||
|
||||
def opt_identity(self, i):
|
||||
"""Identity for public-key authentication"""
|
||||
self.identitys.append(i)
|
||||
|
||||
def opt_ciphers(self, ciphers):
|
||||
"Select encryption algorithms"
|
||||
ciphers = ciphers.split(',')
|
||||
for cipher in ciphers:
|
||||
if not SSHCiphers.cipherMap.has_key(cipher):
|
||||
sys.exit("Unknown cipher type '%s'" % cipher)
|
||||
self['ciphers'] = ciphers
|
||||
|
||||
|
||||
def opt_macs(self, macs):
|
||||
"Specify MAC algorithms"
|
||||
macs = macs.split(',')
|
||||
for mac in macs:
|
||||
if not SSHCiphers.macMap.has_key(mac):
|
||||
sys.exit("Unknown mac type '%s'" % mac)
|
||||
self['macs'] = macs
|
||||
|
||||
def opt_host_key_algorithms(self, hkas):
|
||||
"Select host key algorithms"
|
||||
hkas = hkas.split(',')
|
||||
for hka in hkas:
|
||||
if hka not in SSHClientTransport.supportedPublicKeys:
|
||||
sys.exit("Unknown host key type '%s'" % hka)
|
||||
self['host-key-algorithms'] = hkas
|
||||
|
||||
def opt_user_authentications(self, uas):
|
||||
"Choose how to authenticate to the remote server"
|
||||
self['user-authentications'] = uas.split(',')
|
||||
|
||||
# def opt_compress(self):
|
||||
# "Enable compression"
|
||||
# self.enableCompression = 1
|
||||
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']
|
@ -0,0 +1,832 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Endpoint implementations of various SSH interactions.
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
'AuthenticationFailed', 'SSHCommandAddress', 'SSHCommandClientEndpoint']
|
||||
|
||||
from struct import unpack
|
||||
from os.path import expanduser
|
||||
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet.error import ConnectionDone, ProcessTerminated
|
||||
from twisted.internet.interfaces import IStreamClientEndpoint
|
||||
from twisted.internet.protocol import Factory
|
||||
from twisted.internet.defer import Deferred, succeed, CancelledError
|
||||
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
|
||||
|
||||
from twisted.conch.ssh.keys import Key
|
||||
from twisted.conch.ssh.common import NS
|
||||
from twisted.conch.ssh.transport import SSHClientTransport
|
||||
from twisted.conch.ssh.connection import SSHConnection
|
||||
from twisted.conch.ssh.userauth import SSHUserAuthClient
|
||||
from twisted.conch.ssh.channel import SSHChannel
|
||||
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
|
||||
from twisted.conch.client.agent import SSHAgentClient
|
||||
from twisted.conch.client.default import _KNOWN_HOSTS
|
||||
|
||||
|
||||
class AuthenticationFailed(Exception):
|
||||
"""
|
||||
An SSH session could not be established because authentication was not
|
||||
successful.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
# This should be public. See #6541.
|
||||
class _ISSHConnectionCreator(Interface):
|
||||
"""
|
||||
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
|
||||
"""
|
||||
def secureConnection():
|
||||
"""
|
||||
Return a new, connected, secured, but not yet authenticated instance of
|
||||
L{twisted.conch.ssh.transport.SSHServerTransport} or
|
||||
L{twisted.conch.ssh.transport.SSHClientTransport}.
|
||||
"""
|
||||
|
||||
|
||||
def cleanupConnection(connection, immediate):
|
||||
"""
|
||||
Perform cleanup necessary for a connection object previously returned
|
||||
from this creator's C{secureConnection} method.
|
||||
|
||||
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
|
||||
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
|
||||
previous call to C{secureConnection}. It is no longer needed by the
|
||||
caller of that method and may be closed or otherwise cleaned up as
|
||||
necessary.
|
||||
|
||||
@param immediate: If C{True} don't wait for any network communication,
|
||||
just close the connection immediately and as aggressively as
|
||||
necessary.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class SSHCommandAddress(object):
|
||||
"""
|
||||
An L{SSHCommandAddress} instance represents the address of an SSH server, a
|
||||
username which was used to authenticate with that server, and a command
|
||||
which was run there.
|
||||
|
||||
@ivar server: See L{__init__}
|
||||
@ivar username: See L{__init__}
|
||||
@ivar command: See L{__init__}
|
||||
"""
|
||||
def __init__(self, server, username, command):
|
||||
"""
|
||||
@param server: The address of the SSH server on which the command is
|
||||
running.
|
||||
@type server: L{IAddress} provider
|
||||
|
||||
@param username: An authentication username which was used to
|
||||
authenticate against the server at the given address.
|
||||
@type username: L{bytes}
|
||||
|
||||
@param command: A command which was run in a session channel on the
|
||||
server at the given address.
|
||||
@type command: L{bytes}
|
||||
"""
|
||||
self.server = server
|
||||
self.username = username
|
||||
self.command = command
|
||||
|
||||
|
||||
|
||||
class _CommandChannel(SSHChannel):
|
||||
"""
|
||||
A L{_CommandChannel} executes a command in a session channel and connects
|
||||
its input and output to an L{IProtocol} provider.
|
||||
|
||||
@ivar _creator: See L{__init__}
|
||||
@ivar _command: See L{__init__}
|
||||
@ivar _protocolFactory: See L{__init__}
|
||||
@ivar _commandConnected: See L{__init__}
|
||||
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
|
||||
which is hooked up to the running command's input and output streams.
|
||||
"""
|
||||
name = b'session'
|
||||
|
||||
def __init__(self, creator, command, protocolFactory, commandConnected):
|
||||
"""
|
||||
@param creator: The L{_ISSHConnectionCreator} provider which was used
|
||||
to get the connection which this channel exists on.
|
||||
@type creator: L{_ISSHConnectionCreator} provider
|
||||
|
||||
@param command: The command to be executed.
|
||||
@type command: L{bytes}
|
||||
|
||||
@param protocolFactory: A client factory to use to build a L{IProtocol}
|
||||
provider to use to associate with the running command.
|
||||
|
||||
@param commandConnected: A L{Deferred} to use to signal that execution
|
||||
of the command has failed or that it has succeeded and the command
|
||||
is now running.
|
||||
@type commandConnected: L{Deferred}
|
||||
"""
|
||||
SSHChannel.__init__(self)
|
||||
self._creator = creator
|
||||
self._command = command
|
||||
self._protocolFactory = protocolFactory
|
||||
self._commandConnected = commandConnected
|
||||
self._reason = None
|
||||
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
When the request to open a new channel to run this command in fails,
|
||||
fire the C{commandConnected} deferred with a failure indicating that.
|
||||
"""
|
||||
self._commandConnected.errback(reason)
|
||||
|
||||
|
||||
def channelOpen(self, ignored):
|
||||
"""
|
||||
When the request to open a new channel to run this command in succeeds,
|
||||
issue an C{"exec"} request to run the command.
|
||||
"""
|
||||
command = self.conn.sendRequest(
|
||||
self, 'exec', NS(self._command), wantReply=True)
|
||||
command.addCallbacks(self._execSuccess, self._execFailure)
|
||||
|
||||
|
||||
def _execFailure(self, reason):
|
||||
"""
|
||||
When the request to execute the command in this channel fails, fire the
|
||||
C{commandConnected} deferred with a failure indicating this.
|
||||
|
||||
@param reason: The cause of the command execution failure.
|
||||
@type reason: L{Failure}
|
||||
"""
|
||||
self._commandConnected.errback(reason)
|
||||
|
||||
|
||||
def _execSuccess(self, ignored):
|
||||
"""
|
||||
When the request to execute the command in this channel succeeds, use
|
||||
C{protocolFactory} to build a protocol to handle the command's input and
|
||||
output and connect the protocol to a transport representing those
|
||||
streams.
|
||||
|
||||
Also fire C{commandConnected} with the created protocol after it is
|
||||
connected to its transport.
|
||||
|
||||
@param ignored: The (ignored) result of the execute request
|
||||
"""
|
||||
self._protocol = self._protocolFactory.buildProtocol(
|
||||
SSHCommandAddress(
|
||||
self.conn.transport.transport.getPeer(),
|
||||
self.conn.transport.creator.username,
|
||||
self.conn.transport.creator.command))
|
||||
self._protocol.makeConnection(self)
|
||||
self._commandConnected.callback(self._protocol)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
When the command's stdout data arrives over the channel, deliver it to
|
||||
the protocol instance.
|
||||
|
||||
@param data: The bytes from the command's stdout.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
self._protocol.dataReceived(data)
|
||||
|
||||
|
||||
def request_exit_status(self, data):
|
||||
"""
|
||||
When the server sends the command's exit status, record it for later
|
||||
delivery to the protocol.
|
||||
|
||||
@param data: The network-order four byte representation of the exit
|
||||
status of the command.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
(status,) = unpack('>L', data)
|
||||
if status != 0:
|
||||
self._reason = ProcessTerminated(status, None, None)
|
||||
|
||||
|
||||
def request_exit_signal(self, data):
|
||||
"""
|
||||
When the server sends the command's exit status, record it for later
|
||||
delivery to the protocol.
|
||||
|
||||
@param data: The network-order four byte representation of the exit
|
||||
signal of the command.
|
||||
@type data: L{bytes}
|
||||
"""
|
||||
(signal,) = unpack('>L', data)
|
||||
self._reason = ProcessTerminated(None, signal, None)
|
||||
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
When the channel closes, deliver disconnection notification to the
|
||||
protocol.
|
||||
"""
|
||||
self._creator.cleanupConnection(self.conn, False)
|
||||
if self._reason is None:
|
||||
reason = ConnectionDone("ssh channel closed")
|
||||
else:
|
||||
reason = self._reason
|
||||
self._protocol.connectionLost(Failure(reason))
|
||||
|
||||
|
||||
|
||||
class _ConnectionReady(SSHConnection):
|
||||
"""
|
||||
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
|
||||
propagates the I{serviceStarted} event to a L{Deferred} to be handled
|
||||
elsewhere.
|
||||
"""
|
||||
def __init__(self, ready):
|
||||
"""
|
||||
@param ready: A L{Deferred} which should be fired when
|
||||
I{serviceStarted} happens.
|
||||
"""
|
||||
SSHConnection.__init__(self)
|
||||
self._ready = ready
|
||||
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
When the SSH I{connection} I{service} this object represents is ready to
|
||||
be used, fire the C{connectionReady} L{Deferred} to publish that event
|
||||
to some other interested party.
|
||||
|
||||
"""
|
||||
self._ready.callback(self)
|
||||
del self._ready
|
||||
|
||||
|
||||
|
||||
class _UserAuth(SSHUserAuthClient):
|
||||
"""
|
||||
L{_UserAuth} implements the client part of SSH user authentication in the
|
||||
convenient way a user might expect if they are familiar with the
|
||||
interactive I{ssh} command line client.
|
||||
|
||||
L{_UserAuth} supports key-based authentication, password-based
|
||||
authentication, and delegating authentication to an agent.
|
||||
"""
|
||||
password = None
|
||||
keys = None
|
||||
agent = None
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Retrieve the next public key object to offer to the server, possibly
|
||||
delegating to an authentication agent if there is one.
|
||||
|
||||
@return: The public part of a key pair that could be used to
|
||||
authenticate with the server, or C{None} if there are no more public
|
||||
keys to try.
|
||||
@rtype: L{twisted.conch.ssh.keys.Key} or L{types.NoneType}
|
||||
"""
|
||||
if self.agent is not None:
|
||||
return self.agent.getPublicKey()
|
||||
|
||||
if self.keys:
|
||||
self.key = self.keys.pop(0)
|
||||
else:
|
||||
self.key = None
|
||||
return self.key.public()
|
||||
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Extend the base signing behavior by using an SSH agent to sign the
|
||||
data, if one is available.
|
||||
|
||||
@type publicKey: L{Key}
|
||||
@type signData: C{str}
|
||||
"""
|
||||
if self.agent is not None:
|
||||
return self.agent.signData(publicKey.blob(), signData)
|
||||
else:
|
||||
return SSHUserAuthClient.signData(self, publicKey, signData)
|
||||
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Get the private part of a key pair to use for authentication. The key
|
||||
corresponds to the public part most recently returned from
|
||||
C{getPublicKey}.
|
||||
|
||||
@return: A L{Deferred} which fires with the private key.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return succeed(self.key)
|
||||
|
||||
|
||||
def getPassword(self):
|
||||
"""
|
||||
Get the password to use for authentication.
|
||||
|
||||
@return: A L{Deferred} which fires with the password, or C{None} if the
|
||||
password was not specified.
|
||||
"""
|
||||
if self.password is None:
|
||||
return
|
||||
return succeed(self.password)
|
||||
|
||||
|
||||
def ssh_USERAUTH_SUCCESS(self, packet):
|
||||
"""
|
||||
Handle user authentication success in the normal way, but also make a
|
||||
note of the state change on the L{_CommandTransport}.
|
||||
"""
|
||||
self.transport._state = b'CHANNELLING'
|
||||
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
|
||||
|
||||
|
||||
|
||||
class _CommandTransport(SSHClientTransport):
|
||||
"""
|
||||
L{_CommandTransport} is an SSH client I{transport} which includes a host key
|
||||
verification step before it will proceed to secure the connection.
|
||||
|
||||
L{_CommandTransport} also knows how to set up a connection to an
|
||||
authentication agent if it is told where it can connect to one.
|
||||
"""
|
||||
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
|
||||
_state = b'STARTING'
|
||||
|
||||
_hostKeyFailure = None
|
||||
|
||||
|
||||
def __init__(self, creator):
|
||||
"""
|
||||
@param creator: The L{_NewConnectionHelper} that created this
|
||||
connection.
|
||||
|
||||
@type creator: L{_NewConnectionHelper}.
|
||||
"""
|
||||
self.connectionReady = Deferred(
|
||||
lambda d: self.transport.abortConnection())
|
||||
# Clear the reference to that deferred to help the garbage collector
|
||||
# and to signal to other parts of this implementation (in particular
|
||||
# connectionLost) that it has already been fired and does not need to
|
||||
# be fired again.
|
||||
def readyFired(result):
|
||||
self.connectionReady = None
|
||||
return result
|
||||
self.connectionReady.addBoth(readyFired)
|
||||
self.creator = creator
|
||||
|
||||
|
||||
def verifyHostKey(self, hostKey, fingerprint):
|
||||
"""
|
||||
Ask the L{KnownHostsFile} provider available on the factory which
|
||||
created this protocol this protocol to verify the given host key.
|
||||
|
||||
@return: A L{Deferred} which fires with the result of
|
||||
L{KnownHostsFile.verifyHostKey}.
|
||||
"""
|
||||
hostname = self.creator.hostname
|
||||
ip = self.transport.getPeer().host
|
||||
|
||||
self._state = b'SECURING'
|
||||
d = self.creator.knownHosts.verifyHostKey(
|
||||
self.creator.ui, hostname, ip, Key.fromString(hostKey))
|
||||
d.addErrback(self._saveHostKeyFailure)
|
||||
return d
|
||||
|
||||
|
||||
def _saveHostKeyFailure(self, reason):
|
||||
"""
|
||||
When host key verification fails, record the reason for the failure in
|
||||
order to fire a L{Deferred} with it later.
|
||||
|
||||
@param reason: The cause of the host key verification failure.
|
||||
@type reason: L{Failure}
|
||||
|
||||
@return: C{reason}
|
||||
@rtype: L{Failure}
|
||||
"""
|
||||
self._hostKeyFailure = reason
|
||||
return reason
|
||||
|
||||
|
||||
def connectionSecure(self):
|
||||
"""
|
||||
When the connection is secure, start the authentication process.
|
||||
"""
|
||||
self._state = b'AUTHENTICATING'
|
||||
|
||||
command = _ConnectionReady(self.connectionReady)
|
||||
|
||||
userauth = _UserAuth(self.creator.username, command)
|
||||
userauth.password = self.creator.password
|
||||
if self.creator.keys:
|
||||
userauth.keys = list(self.creator.keys)
|
||||
|
||||
if self.creator.agentEndpoint is not None:
|
||||
d = self._connectToAgent(userauth, self.creator.agentEndpoint)
|
||||
else:
|
||||
d = succeed(None)
|
||||
|
||||
def maybeGotAgent(ignored):
|
||||
self.requestService(userauth)
|
||||
d.addBoth(maybeGotAgent)
|
||||
|
||||
|
||||
def _connectToAgent(self, userauth, endpoint):
|
||||
"""
|
||||
Set up a connection to the authentication agent and trigger its
|
||||
initialization.
|
||||
|
||||
@param userauth: The L{_UserAuth} instance which is in charge of the
|
||||
overall authentication process.
|
||||
@type userauth: L{_UserAuth}
|
||||
|
||||
@param endpoint: An endpoint which can be used to connect to the
|
||||
authentication agent.
|
||||
@type endpoint: L{IStreamClientEndpoint} provider
|
||||
|
||||
@return: A L{Deferred} which fires when the agent connection is ready
|
||||
for use.
|
||||
"""
|
||||
factory = Factory()
|
||||
factory.protocol = SSHAgentClient
|
||||
d = endpoint.connect(factory)
|
||||
def connected(agent):
|
||||
userauth.agent = agent
|
||||
return agent.getPublicKeys()
|
||||
d.addCallback(connected)
|
||||
return d
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
When the underlying connection to the SSH server is lost, if there were
|
||||
any connection setup errors, propagate them.
|
||||
"""
|
||||
if self._state == b'RUNNING' or self.connectionReady is None:
|
||||
return
|
||||
if self._state == b'SECURING' and self._hostKeyFailure is not None:
|
||||
reason = self._hostKeyFailure
|
||||
elif self._state == b'AUTHENTICATING':
|
||||
reason = Failure(
|
||||
AuthenticationFailed("Connection lost while authenticating"))
|
||||
self.connectionReady.errback(reason)
|
||||
|
||||
|
||||
|
||||
@implementer(IStreamClientEndpoint)
|
||||
class SSHCommandClientEndpoint(object):
|
||||
"""
|
||||
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
|
||||
SSH servers.
|
||||
|
||||
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
|
||||
it in any one of a number of different ways (keys, passwords, agents),
|
||||
launch a command over that connection and then associate its input and
|
||||
output with a protocol.
|
||||
|
||||
It can also re-use an existing, already-authenticated SSH connection
|
||||
(perhaps one which already has some SSH channels being used for other
|
||||
purposes). In this case it creates a new SSH channel to use to execute the
|
||||
command. Notably this means it supports multiplexing several different
|
||||
command invocations over a single SSH connection.
|
||||
"""
|
||||
|
||||
def __init__(self, creator, command):
|
||||
"""
|
||||
@param creator: An L{_ISSHConnectionCreator} provider which will be
|
||||
used to set up the SSH connection which will be used to run a
|
||||
command.
|
||||
@type creator: L{_ISSHConnectionCreator} provider
|
||||
|
||||
@param command: The command line to execute on the SSH server. This
|
||||
byte string is interpreted by a shell on the SSH server, so it may
|
||||
have a value like C{"ls /"}. Take care when trying to run a command
|
||||
like C{"/Volumes/My Stuff/a-program"} - spaces (and other special
|
||||
bytes) may require escaping.
|
||||
@type command: L{bytes}
|
||||
|
||||
"""
|
||||
self._creator = creator
|
||||
self._command = command
|
||||
|
||||
|
||||
@classmethod
|
||||
def newConnection(cls, reactor, command, username, hostname, port=None,
|
||||
keys=None, password=None, agentEndpoint=None,
|
||||
knownHosts=None, ui=None):
|
||||
"""
|
||||
Create and return a new endpoint which will try to create a new
|
||||
connection to an SSH server and run a command over it. It will also
|
||||
close the connection if there are problems leading up to the command
|
||||
being executed, after the command finishes, or if the connection
|
||||
L{Deferred} is cancelled.
|
||||
|
||||
@param reactor: The reactor to use to establish the connection.
|
||||
@type reactor: L{IReactorTCP} provider
|
||||
|
||||
@param command: See L{__init__}'s C{command} argument.
|
||||
|
||||
@param username: The username with which to authenticate to the SSH
|
||||
server.
|
||||
@type username: L{bytes}
|
||||
|
||||
@param hostname: The hostname of the SSH server.
|
||||
@type hostname: L{bytes}
|
||||
|
||||
@param port: The port number of the SSH server. By default, the
|
||||
standard SSH port number is used.
|
||||
@type port: L{int}
|
||||
|
||||
@param keys: Private keys with which to authenticate to the SSH server,
|
||||
if key authentication is to be attempted (otherwise C{None}).
|
||||
@type keys: L{list} of L{Key}
|
||||
|
||||
@param password: The password with which to authenticate to the SSH
|
||||
server, if password authentication is to be attempted (otherwise
|
||||
C{None}).
|
||||
@type password: L{bytes} or L{types.NoneType}
|
||||
|
||||
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
|
||||
used to connect to an SSH agent, if one is to be used to help with
|
||||
authentication.
|
||||
@type agentEndpoint: L{IStreamClientEndpoint} provider
|
||||
|
||||
@param knownHosts: The currently known host keys, used to check the
|
||||
host key presented by the server we actually connect to.
|
||||
@type knownHosts: L{KnownHostsFile}
|
||||
|
||||
@param ui: An object for interacting with users to make decisions about
|
||||
whether to accept the server host keys. If C{None}, a L{ConsoleUI}
|
||||
connected to /dev/tty will be used; if /dev/tty is unavailable, an
|
||||
object which answers C{b"no"} to all prompts will be used.
|
||||
@type ui: L{NoneType} or L{ConsoleUI}
|
||||
|
||||
@return: A new instance of C{cls} (probably
|
||||
L{SSHCommandClientEndpoint}).
|
||||
"""
|
||||
helper = _NewConnectionHelper(
|
||||
reactor, hostname, port, command, username, keys, password,
|
||||
agentEndpoint, knownHosts, ui)
|
||||
return cls(helper, command)
|
||||
|
||||
|
||||
@classmethod
|
||||
def existingConnection(cls, connection, command):
|
||||
"""
|
||||
Create and return a new endpoint which will try to open a new channel on
|
||||
an existing SSH connection and run a command over it. It will B{not}
|
||||
close the connection if there is a problem executing the command or
|
||||
after the command finishes.
|
||||
|
||||
@param connection: An existing connection to an SSH server.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
|
||||
C{command} parameter.
|
||||
@type command: L{bytes}
|
||||
|
||||
@return: A new instance of C{cls} (probably
|
||||
L{SSHCommandClientEndpoint}).
|
||||
"""
|
||||
helper = _ExistingConnectionHelper(connection)
|
||||
return cls(helper, command)
|
||||
|
||||
|
||||
def connect(self, protocolFactory):
|
||||
"""
|
||||
Set up an SSH connection, use a channel from that connection to launch
|
||||
a command, and hook the stdin and stdout of that command up as a
|
||||
transport for a protocol created by the given factory.
|
||||
|
||||
@param protocolFactory: A L{Factory} to use to create the protocol
|
||||
which will be connected to the stdin and stdout of the command on
|
||||
the SSH server.
|
||||
|
||||
@return: A L{Deferred} which will fire with an error if the connection
|
||||
cannot be set up for any reason or with the protocol instance
|
||||
created by C{protocolFactory} once it has been connected to the
|
||||
command.
|
||||
"""
|
||||
d = self._creator.secureConnection()
|
||||
d.addCallback(self._executeCommand, protocolFactory)
|
||||
return d
|
||||
|
||||
|
||||
def _executeCommand(self, connection, protocolFactory):
|
||||
"""
|
||||
Given a secured SSH connection, try to execute a command in a new
|
||||
channel created on it and associate the result with a protocol from the
|
||||
given factory.
|
||||
|
||||
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
|
||||
C{connection} parameter.
|
||||
|
||||
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
|
||||
C{protocolFactory} parameter.
|
||||
|
||||
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
|
||||
"""
|
||||
commandConnected = Deferred()
|
||||
def disconnectOnFailure(passthrough):
|
||||
# Close the connection immediately in case of cancellation, since
|
||||
# that implies user wants it gone immediately (e.g. a timeout):
|
||||
immediate = passthrough.check(CancelledError)
|
||||
self._creator.cleanupConnection(connection, immediate)
|
||||
return passthrough
|
||||
commandConnected.addErrback(disconnectOnFailure)
|
||||
|
||||
channel = _CommandChannel(
|
||||
self._creator, self._command, protocolFactory, commandConnected)
|
||||
connection.openChannel(channel)
|
||||
return commandConnected
|
||||
|
||||
|
||||
|
||||
class _ReadFile(object):
|
||||
"""
|
||||
A weakly file-like object which can be used with L{KnownHostsFile} to
|
||||
respond in the negative to all prompts for decisions.
|
||||
"""
|
||||
def __init__(self, contents):
|
||||
"""
|
||||
@param contents: L{bytes} which will be returned from every C{readline}
|
||||
call.
|
||||
"""
|
||||
self._contents = contents
|
||||
|
||||
|
||||
def write(self, data):
|
||||
"""
|
||||
No-op.
|
||||
|
||||
@param data: ignored
|
||||
"""
|
||||
|
||||
|
||||
def readline(self, count=-1):
|
||||
"""
|
||||
Always give back the byte string that this L{_ReadFile} was initialized
|
||||
with.
|
||||
|
||||
@param count: ignored
|
||||
|
||||
@return: A fixed byte-string.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return self._contents
|
||||
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
No-op.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(_ISSHConnectionCreator)
|
||||
class _NewConnectionHelper(object):
|
||||
"""
|
||||
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
|
||||
establishing a brand new SSH connection, securing it, and authenticating.
|
||||
"""
|
||||
_KNOWN_HOSTS = _KNOWN_HOSTS
|
||||
port = 22
|
||||
|
||||
def __init__(self, reactor, hostname, port, command, username, keys,
|
||||
password, agentEndpoint, knownHosts, ui,
|
||||
tty=FilePath(b"/dev/tty")):
|
||||
"""
|
||||
@param tty: The path of the tty device to use in case C{ui} is C{None}.
|
||||
@type tty: L{FilePath}
|
||||
|
||||
@see: L{SSHCommandClientEndpoint.newConnection}
|
||||
"""
|
||||
self.reactor = reactor
|
||||
self.hostname = hostname
|
||||
if port is not None:
|
||||
self.port = port
|
||||
self.command = command
|
||||
self.username = username
|
||||
self.keys = keys
|
||||
self.password = password
|
||||
self.agentEndpoint = agentEndpoint
|
||||
if knownHosts is None:
|
||||
knownHosts = self._knownHosts()
|
||||
self.knownHosts = knownHosts
|
||||
|
||||
if ui is None:
|
||||
ui = ConsoleUI(self._opener)
|
||||
self.ui = ui
|
||||
self.tty = tty
|
||||
|
||||
|
||||
def _opener(self):
|
||||
"""
|
||||
Open the tty if possible, otherwise give back a file-like object from
|
||||
which C{b"no"} can be read.
|
||||
|
||||
For use as the opener argument to L{ConsoleUI}.
|
||||
"""
|
||||
try:
|
||||
return self.tty.open("r+")
|
||||
except:
|
||||
# Give back a file-like object from which can be read a byte string
|
||||
# that KnownHostsFile recognizes as rejecting some option (b"no").
|
||||
return _ReadFile(b"no")
|
||||
|
||||
|
||||
@classmethod
|
||||
def _knownHosts(cls):
|
||||
"""
|
||||
@return: A L{KnownHostsFile} instance pointed at the user's personal
|
||||
I{known hosts} file.
|
||||
@type: L{KnownHostsFile}
|
||||
"""
|
||||
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
|
||||
|
||||
|
||||
def secureConnection(self):
|
||||
"""
|
||||
Create and return a new SSH connection which has been secured and on
|
||||
which authentication has already happened.
|
||||
|
||||
@return: A L{Deferred} which fires with the ready-to-use connection or
|
||||
with a failure if something prevents the connection from being
|
||||
setup, secured, or authenticated.
|
||||
"""
|
||||
protocol = _CommandTransport(self)
|
||||
ready = protocol.connectionReady
|
||||
|
||||
sshClient = TCP4ClientEndpoint(self.reactor, self.hostname, self.port)
|
||||
|
||||
d = connectProtocol(sshClient, protocol)
|
||||
d.addCallback(lambda ignored: ready)
|
||||
return d
|
||||
|
||||
|
||||
def cleanupConnection(self, connection, immediate):
|
||||
"""
|
||||
Clean up the connection by closing it. The command running on the
|
||||
endpoint has ended so the connection is no longer needed.
|
||||
|
||||
@param connection: The L{SSHConnection} to close.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param immediate: Whether to close connection immediately.
|
||||
@type immediate: L{bool}.
|
||||
"""
|
||||
if immediate:
|
||||
# We're assuming the underlying connection is a ITCPTransport,
|
||||
# which is what the current implementation is restricted to:
|
||||
connection.transport.transport.abortConnection()
|
||||
else:
|
||||
connection.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
@implementer(_ISSHConnectionCreator)
|
||||
class _ExistingConnectionHelper(object):
|
||||
"""
|
||||
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
|
||||
handing out an existing SSH connection which is supplied to its
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
"""
|
||||
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
|
||||
C{connection} parameter.
|
||||
"""
|
||||
self.connection = connection
|
||||
|
||||
|
||||
def secureConnection(self):
|
||||
"""
|
||||
@return: A L{Deferred} that fires synchronously with the
|
||||
already-established connection object.
|
||||
"""
|
||||
return succeed(self.connection)
|
||||
|
||||
|
||||
def cleanupConnection(self, connection, immediate):
|
||||
"""
|
||||
Do not do any cleanup on the connection. Leave that responsibility to
|
||||
whatever code created it in the first place.
|
||||
|
||||
@param connection: The L{SSHConnection} which will not be modified in
|
||||
any way.
|
||||
@type connection: L{SSHConnection}
|
||||
|
||||
@param immediate: An argument which will be ignored.
|
||||
@type immediate: L{bool}.
|
||||
"""
|
@ -0,0 +1,102 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
An error to represent bad things happening in Conch.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
from twisted.cred.error import UnauthorizedLogin
|
||||
|
||||
|
||||
|
||||
class ConchError(Exception):
|
||||
def __init__(self, value, data = None):
|
||||
Exception.__init__(self, value, data)
|
||||
self.value = value
|
||||
self.data = data
|
||||
|
||||
|
||||
|
||||
class NotEnoughAuthentication(Exception):
|
||||
"""
|
||||
This is thrown if the authentication is valid, but is not enough to
|
||||
successfully verify the user. i.e. don't retry this type of
|
||||
authentication, try another one.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ValidPublicKey(UnauthorizedLogin):
|
||||
"""
|
||||
Raised by public key checkers when they receive public key credentials
|
||||
that don't contain a signature at all, but are valid in every other way.
|
||||
(e.g. the public key matches one in the user's authorized_keys file).
|
||||
|
||||
Protocol code (eg
|
||||
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
|
||||
attempts to log in using
|
||||
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
|
||||
should be prepared to handle a failure of this type by telling the user to
|
||||
re-authenticate using the same key and to include a signature with the new
|
||||
attempt.
|
||||
|
||||
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class IgnoreAuthentication(Exception):
|
||||
"""
|
||||
This is thrown to let the UserAuthServer know it doesn't need to handle the
|
||||
authentication anymore.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class MissingKeyStoreError(Exception):
|
||||
"""
|
||||
Raised if an SSHAgentServer starts receiving data without its factory
|
||||
providing a keys dict on which to read/write key data.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class UserRejectedKey(Exception):
|
||||
"""
|
||||
The user interactively rejected a key.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class InvalidEntry(Exception):
|
||||
"""
|
||||
An entry in a known_hosts file could not be interpreted as a valid entry.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class HostKeyChanged(Exception):
|
||||
"""
|
||||
The host key of a remote host has changed.
|
||||
|
||||
@ivar offendingEntry: The entry which contains the persistent host key that
|
||||
disagrees with the given host key.
|
||||
|
||||
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
|
||||
|
||||
@ivar path: a reference to the known_hosts file that the offending entry
|
||||
was loaded from
|
||||
|
||||
@type path: L{twisted.python.filepath.FilePath}
|
||||
|
||||
@ivar lineno: The line number of the offending entry in the given path.
|
||||
|
||||
@type lineno: L{int}
|
||||
"""
|
||||
def __init__(self, offendingEntry, path, lineno):
|
||||
Exception.__init__(self)
|
||||
self.offendingEntry = offendingEntry
|
||||
self.path = path
|
||||
self.lineno = lineno
|
@ -0,0 +1,16 @@
|
||||
"""
|
||||
Insults: a replacement for Curses/S-Lang.
|
||||
|
||||
Very basic at the moment."""
|
||||
|
||||
from twisted.python import deprecate, versions
|
||||
|
||||
deprecate.deprecatedModuleAttribute(
|
||||
versions.Version("Twisted", 10, 1, 0),
|
||||
"Please use twisted.conch.insults.helper instead.",
|
||||
__name__, "colors")
|
||||
|
||||
deprecate.deprecatedModuleAttribute(
|
||||
versions.Version("Twisted", 10, 1, 0),
|
||||
"Please use twisted.conch.insults.insults instead.",
|
||||
__name__, "client")
|
@ -0,0 +1,138 @@
|
||||
"""
|
||||
You don't really want to use this module. Try insults.py instead.
|
||||
"""
|
||||
|
||||
from twisted.internet import protocol
|
||||
|
||||
class InsultsClient(protocol.Protocol):
|
||||
|
||||
escapeTimeout = 0.2
|
||||
|
||||
def __init__(self):
|
||||
self.width = self.height = None
|
||||
self.xpos = self.ypos = 0
|
||||
self.commandQueue = []
|
||||
self.inEscape = ''
|
||||
|
||||
def setSize(self, width, height):
|
||||
call = 0
|
||||
if self.width:
|
||||
call = 1
|
||||
self.width = width
|
||||
self.height = height
|
||||
if call:
|
||||
self.windowSizeChanged()
|
||||
|
||||
def dataReceived(self, data):
|
||||
from twisted.internet import reactor
|
||||
for ch in data:
|
||||
if ch == '\x1b':
|
||||
if self.inEscape:
|
||||
self.keyReceived(ch)
|
||||
self.inEscape = ''
|
||||
else:
|
||||
self.inEscape = ch
|
||||
self.escapeCall = reactor.callLater(self.escapeTimeout,
|
||||
self.endEscape)
|
||||
elif ch in 'ABCD' and self.inEscape:
|
||||
self.inEscape = ''
|
||||
self.escapeCall.cancel()
|
||||
if ch == 'A':
|
||||
self.keyReceived('<Up>')
|
||||
elif ch == 'B':
|
||||
self.keyReceived('<Down>')
|
||||
elif ch == 'C':
|
||||
self.keyReceived('<Right>')
|
||||
elif ch == 'D':
|
||||
self.keyReceived('<Left>')
|
||||
elif self.inEscape:
|
||||
self.inEscape += ch
|
||||
else:
|
||||
self.keyReceived(ch)
|
||||
|
||||
def endEscape(self):
|
||||
ch = self.inEscape
|
||||
self.inEscape = ''
|
||||
self.keyReceived(ch)
|
||||
|
||||
def initScreen(self):
|
||||
self.transport.write('\x1b=\x1b[?1h')
|
||||
|
||||
def gotoXY(self, x, y):
|
||||
"""Go to a position on the screen.
|
||||
"""
|
||||
self.xpos = x
|
||||
self.ypos = y
|
||||
self.commandQueue.append(('gotoxy', x, y))
|
||||
|
||||
def writeCh(self, ch):
|
||||
"""Write a character to the screen. If we're at the end of the row,
|
||||
ignore the write.
|
||||
"""
|
||||
if self.xpos < self.width - 1:
|
||||
self.commandQueue.append(('write', ch))
|
||||
self.xpos += 1
|
||||
|
||||
def writeStr(self, s):
|
||||
"""Write a string to the screen. This does not wrap a the edge of the
|
||||
screen, and stops at \\r and \\n.
|
||||
"""
|
||||
s = s[:self.width-self.xpos]
|
||||
if '\n' in s:
|
||||
s=s[:s.find('\n')]
|
||||
if '\r' in s:
|
||||
s=s[:s.find('\r')]
|
||||
self.commandQueue.append(('write', s))
|
||||
self.xpos += len(s)
|
||||
|
||||
def eraseToLine(self):
|
||||
"""Erase from the current position to the end of the line.
|
||||
"""
|
||||
self.commandQueue.append(('eraseeol',))
|
||||
|
||||
def eraseToScreen(self):
|
||||
"""Erase from the current position to the end of the screen.
|
||||
"""
|
||||
self.commandQueue.append(('eraseeos',))
|
||||
|
||||
def clearScreen(self):
|
||||
"""Clear the screen, and return the cursor to 0, 0.
|
||||
"""
|
||||
self.commandQueue = [('cls',)]
|
||||
self.xpos = self.ypos = 0
|
||||
|
||||
def setAttributes(self, *attrs):
|
||||
"""Set the attributes for drawing on the screen.
|
||||
"""
|
||||
self.commandQueue.append(('attributes', attrs))
|
||||
|
||||
def refresh(self):
|
||||
"""Redraw the screen.
|
||||
"""
|
||||
redraw = ''
|
||||
for command in self.commandQueue:
|
||||
if command[0] == 'gotoxy':
|
||||
redraw += '\x1b[%i;%iH' % (command[2]+1, command[1]+1)
|
||||
elif command[0] == 'write':
|
||||
redraw += command[1]
|
||||
elif command[0] == 'eraseeol':
|
||||
redraw += '\x1b[0K'
|
||||
elif command[0] == 'eraseeos':
|
||||
redraw += '\x1b[OJ'
|
||||
elif command[0] == 'cls':
|
||||
redraw += '\x1b[H\x1b[J'
|
||||
elif command[0] == 'attributes':
|
||||
redraw += '\x1b[%sm' % ';'.join(map(str, command[1]))
|
||||
else:
|
||||
print command
|
||||
self.commandQueue = []
|
||||
self.transport.write(redraw)
|
||||
|
||||
def windowSizeChanged(self):
|
||||
"""Called when the size of the window changes.
|
||||
Might want to redraw the screen here, or something.
|
||||
"""
|
||||
|
||||
def keyReceived(self, key):
|
||||
"""Called when the user hits a key.
|
||||
"""
|
@ -0,0 +1,29 @@
|
||||
"""
|
||||
You don't really want to use this module. Try helper.py instead.
|
||||
"""
|
||||
|
||||
CLEAR = 0
|
||||
BOLD = 1
|
||||
DIM = 2
|
||||
ITALIC = 3
|
||||
UNDERSCORE = 4
|
||||
BLINK_SLOW = 5
|
||||
BLINK_FAST = 6
|
||||
REVERSE = 7
|
||||
CONCEALED = 8
|
||||
FG_BLACK = 30
|
||||
FG_RED = 31
|
||||
FG_GREEN = 32
|
||||
FG_YELLOW = 33
|
||||
FG_BLUE = 34
|
||||
FG_MAGENTA = 35
|
||||
FG_CYAN = 36
|
||||
FG_WHITE = 37
|
||||
BG_BLACK = 40
|
||||
BG_RED = 41
|
||||
BG_GREEN = 42
|
||||
BG_YELLOW = 43
|
||||
BG_BLUE = 44
|
||||
BG_MAGENTA = 45
|
||||
BG_CYAN = 46
|
||||
BG_WHITE = 47
|
@ -0,0 +1,462 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_helper -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Partial in-memory terminal emulator
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import re, string
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer, protocol, reactor
|
||||
from twisted.python import log, _textattributes
|
||||
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
|
||||
from twisted.python.versions import Version
|
||||
from twisted.conch.insults import insults
|
||||
|
||||
FOREGROUND = 30
|
||||
BACKGROUND = 40
|
||||
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
|
||||
|
||||
|
||||
|
||||
class _FormattingState(_textattributes._FormattingStateMixin):
|
||||
"""
|
||||
Represents the formatting state/attributes of a single character.
|
||||
|
||||
Character set, intensity, underlinedness, blinkitude, video
|
||||
reversal, as well as foreground and background colors made up a
|
||||
character's attributes.
|
||||
"""
|
||||
compareAttributes = (
|
||||
'charset', 'bold', 'underline', 'blink', 'reverseVideo', 'foreground',
|
||||
'background', '_subtracting')
|
||||
|
||||
|
||||
def __init__(self, charset=insults.G0, bold=False, underline=False,
|
||||
blink=False, reverseVideo=False, foreground=WHITE,
|
||||
background=BLACK, _subtracting=False):
|
||||
self.charset = charset
|
||||
self.bold = bold
|
||||
self.underline = underline
|
||||
self.blink = blink
|
||||
self.reverseVideo = reverseVideo
|
||||
self.foreground = foreground
|
||||
self.background = background
|
||||
self._subtracting = _subtracting
|
||||
|
||||
|
||||
@deprecated(Version('Twisted', 13, 1, 0))
|
||||
def wantOne(self, **kw):
|
||||
"""
|
||||
Add a character attribute to a copy of this formatting state.
|
||||
|
||||
@param **kw: An optional attribute name and value can be provided with
|
||||
a keyword argument.
|
||||
|
||||
@return: A formatting state instance with the new attribute.
|
||||
|
||||
@see: L{DefaultFormattingState._withAttribute}.
|
||||
"""
|
||||
k, v = kw.popitem()
|
||||
return self._withAttribute(k, v)
|
||||
|
||||
|
||||
def toVT102(self):
|
||||
# Spit out a vt102 control sequence that will set up
|
||||
# all the attributes set here. Except charset.
|
||||
attrs = []
|
||||
if self._subtracting:
|
||||
attrs.append(0)
|
||||
if self.bold:
|
||||
attrs.append(insults.BOLD)
|
||||
if self.underline:
|
||||
attrs.append(insults.UNDERLINE)
|
||||
if self.blink:
|
||||
attrs.append(insults.BLINK)
|
||||
if self.reverseVideo:
|
||||
attrs.append(insults.REVERSE_VIDEO)
|
||||
if self.foreground != WHITE:
|
||||
attrs.append(FOREGROUND + self.foreground)
|
||||
if self.background != BLACK:
|
||||
attrs.append(BACKGROUND + self.background)
|
||||
if attrs:
|
||||
return '\x1b[' + ';'.join(map(str, attrs)) + 'm'
|
||||
return ''
|
||||
|
||||
CharacterAttribute = _FormattingState
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version('Twisted', 13, 1, 0),
|
||||
'Use twisted.conch.insults.text.assembleFormattedText instead.',
|
||||
'twisted.conch.insults.helper',
|
||||
'CharacterAttribute')
|
||||
|
||||
|
||||
|
||||
# XXX - need to support scroll regions and scroll history
|
||||
@implementer(insults.ITerminalTransport)
|
||||
class TerminalBuffer(protocol.Protocol):
|
||||
"""
|
||||
An in-memory terminal emulator.
|
||||
"""
|
||||
for keyID in ('UP_ARROW', 'DOWN_ARROW', 'RIGHT_ARROW', 'LEFT_ARROW',
|
||||
'HOME', 'INSERT', 'DELETE', 'END', 'PGUP', 'PGDN',
|
||||
'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9',
|
||||
'F10', 'F11', 'F12'):
|
||||
exec '%s = object()' % (keyID,)
|
||||
|
||||
TAB = '\t'
|
||||
BACKSPACE = '\x7f'
|
||||
|
||||
width = 80
|
||||
height = 24
|
||||
|
||||
fill = ' '
|
||||
void = object()
|
||||
|
||||
def getCharacter(self, x, y):
|
||||
return self.lines[y][x]
|
||||
|
||||
def connectionMade(self):
|
||||
self.reset()
|
||||
|
||||
def write(self, bytes):
|
||||
"""
|
||||
Add the given printable bytes to the terminal.
|
||||
|
||||
Line feeds in C{bytes} will be replaced with carriage return / line
|
||||
feed pairs.
|
||||
"""
|
||||
for b in bytes.replace('\n', '\r\n'):
|
||||
self.insertAtCursor(b)
|
||||
|
||||
def _currentFormattingState(self):
|
||||
return _FormattingState(self.activeCharset, **self.graphicRendition)
|
||||
|
||||
def insertAtCursor(self, b):
|
||||
"""
|
||||
Add one byte to the terminal at the cursor and make consequent state
|
||||
updates.
|
||||
|
||||
If b is a carriage return, move the cursor to the beginning of the
|
||||
current row.
|
||||
|
||||
If b is a line feed, move the cursor to the next row or scroll down if
|
||||
the cursor is already in the last row.
|
||||
|
||||
Otherwise, if b is printable, put it at the cursor position (inserting
|
||||
or overwriting as dictated by the current mode) and move the cursor.
|
||||
"""
|
||||
if b == '\r':
|
||||
self.x = 0
|
||||
elif b == '\n':
|
||||
self._scrollDown()
|
||||
elif b in string.printable:
|
||||
if self.x >= self.width:
|
||||
self.nextLine()
|
||||
ch = (b, self._currentFormattingState())
|
||||
if self.modes.get(insults.modes.IRM):
|
||||
self.lines[self.y][self.x:self.x] = [ch]
|
||||
self.lines[self.y].pop()
|
||||
else:
|
||||
self.lines[self.y][self.x] = ch
|
||||
self.x += 1
|
||||
|
||||
def _emptyLine(self, width):
|
||||
return [(self.void, self._currentFormattingState())
|
||||
for i in xrange(width)]
|
||||
|
||||
def _scrollDown(self):
|
||||
self.y += 1
|
||||
if self.y >= self.height:
|
||||
self.y -= 1
|
||||
del self.lines[0]
|
||||
self.lines.append(self._emptyLine(self.width))
|
||||
|
||||
def _scrollUp(self):
|
||||
self.y -= 1
|
||||
if self.y < 0:
|
||||
self.y = 0
|
||||
del self.lines[-1]
|
||||
self.lines.insert(0, self._emptyLine(self.width))
|
||||
|
||||
def cursorUp(self, n=1):
|
||||
self.y = max(0, self.y - n)
|
||||
|
||||
def cursorDown(self, n=1):
|
||||
self.y = min(self.height - 1, self.y + n)
|
||||
|
||||
def cursorBackward(self, n=1):
|
||||
self.x = max(0, self.x - n)
|
||||
|
||||
def cursorForward(self, n=1):
|
||||
self.x = min(self.width, self.x + n)
|
||||
|
||||
def cursorPosition(self, column, line):
|
||||
self.x = column
|
||||
self.y = line
|
||||
|
||||
def cursorHome(self):
|
||||
self.x = self.home.x
|
||||
self.y = self.home.y
|
||||
|
||||
def index(self):
|
||||
self._scrollDown()
|
||||
|
||||
def reverseIndex(self):
|
||||
self._scrollUp()
|
||||
|
||||
def nextLine(self):
|
||||
"""
|
||||
Update the cursor position attributes and scroll down if appropriate.
|
||||
"""
|
||||
self.x = 0
|
||||
self._scrollDown()
|
||||
|
||||
def saveCursor(self):
|
||||
self._savedCursor = (self.x, self.y)
|
||||
|
||||
def restoreCursor(self):
|
||||
self.x, self.y = self._savedCursor
|
||||
del self._savedCursor
|
||||
|
||||
def setModes(self, modes):
|
||||
for m in modes:
|
||||
self.modes[m] = True
|
||||
|
||||
def resetModes(self, modes):
|
||||
for m in modes:
|
||||
try:
|
||||
del self.modes[m]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def setPrivateModes(self, modes):
|
||||
"""
|
||||
Enable the given modes.
|
||||
|
||||
Track which modes have been enabled so that the implementations of
|
||||
other L{insults.ITerminalTransport} methods can be properly implemented
|
||||
to respect these settings.
|
||||
|
||||
@see: L{resetPrivateModes}
|
||||
@see: L{insults.ITerminalTransport.setPrivateModes}
|
||||
"""
|
||||
for m in modes:
|
||||
self.privateModes[m] = True
|
||||
|
||||
|
||||
def resetPrivateModes(self, modes):
|
||||
"""
|
||||
Disable the given modes.
|
||||
|
||||
@see: L{setPrivateModes}
|
||||
@see: L{insults.ITerminalTransport.resetPrivateModes}
|
||||
"""
|
||||
for m in modes:
|
||||
try:
|
||||
del self.privateModes[m]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
def applicationKeypadMode(self):
|
||||
self.keypadMode = 'app'
|
||||
|
||||
def numericKeypadMode(self):
|
||||
self.keypadMode = 'num'
|
||||
|
||||
def selectCharacterSet(self, charSet, which):
|
||||
self.charsets[which] = charSet
|
||||
|
||||
def shiftIn(self):
|
||||
self.activeCharset = insults.G0
|
||||
|
||||
def shiftOut(self):
|
||||
self.activeCharset = insults.G1
|
||||
|
||||
def singleShift2(self):
|
||||
oldActiveCharset = self.activeCharset
|
||||
self.activeCharset = insults.G2
|
||||
f = self.insertAtCursor
|
||||
def insertAtCursor(b):
|
||||
f(b)
|
||||
del self.insertAtCursor
|
||||
self.activeCharset = oldActiveCharset
|
||||
self.insertAtCursor = insertAtCursor
|
||||
|
||||
def singleShift3(self):
|
||||
oldActiveCharset = self.activeCharset
|
||||
self.activeCharset = insults.G3
|
||||
f = self.insertAtCursor
|
||||
def insertAtCursor(b):
|
||||
f(b)
|
||||
del self.insertAtCursor
|
||||
self.activeCharset = oldActiveCharset
|
||||
self.insertAtCursor = insertAtCursor
|
||||
|
||||
def selectGraphicRendition(self, *attributes):
|
||||
for a in attributes:
|
||||
if a == insults.NORMAL:
|
||||
self.graphicRendition = {
|
||||
'bold': False,
|
||||
'underline': False,
|
||||
'blink': False,
|
||||
'reverseVideo': False,
|
||||
'foreground': WHITE,
|
||||
'background': BLACK}
|
||||
elif a == insults.BOLD:
|
||||
self.graphicRendition['bold'] = True
|
||||
elif a == insults.UNDERLINE:
|
||||
self.graphicRendition['underline'] = True
|
||||
elif a == insults.BLINK:
|
||||
self.graphicRendition['blink'] = True
|
||||
elif a == insults.REVERSE_VIDEO:
|
||||
self.graphicRendition['reverseVideo'] = True
|
||||
else:
|
||||
try:
|
||||
v = int(a)
|
||||
except ValueError:
|
||||
log.msg("Unknown graphic rendition attribute: " + repr(a))
|
||||
else:
|
||||
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
|
||||
self.graphicRendition['foreground'] = v - FOREGROUND
|
||||
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
|
||||
self.graphicRendition['background'] = v - BACKGROUND
|
||||
else:
|
||||
log.msg("Unknown graphic rendition attribute: " + repr(a))
|
||||
|
||||
def eraseLine(self):
|
||||
self.lines[self.y] = self._emptyLine(self.width)
|
||||
|
||||
def eraseToLineEnd(self):
|
||||
width = self.width - self.x
|
||||
self.lines[self.y][self.x:] = self._emptyLine(width)
|
||||
|
||||
def eraseToLineBeginning(self):
|
||||
self.lines[self.y][:self.x + 1] = self._emptyLine(self.x + 1)
|
||||
|
||||
def eraseDisplay(self):
|
||||
self.lines = [self._emptyLine(self.width) for i in xrange(self.height)]
|
||||
|
||||
def eraseToDisplayEnd(self):
|
||||
self.eraseToLineEnd()
|
||||
height = self.height - self.y - 1
|
||||
self.lines[self.y + 1:] = [self._emptyLine(self.width) for i in range(height)]
|
||||
|
||||
def eraseToDisplayBeginning(self):
|
||||
self.eraseToLineBeginning()
|
||||
self.lines[:self.y] = [self._emptyLine(self.width) for i in range(self.y)]
|
||||
|
||||
def deleteCharacter(self, n=1):
|
||||
del self.lines[self.y][self.x:self.x+n]
|
||||
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
|
||||
|
||||
def insertLine(self, n=1):
|
||||
self.lines[self.y:self.y] = [self._emptyLine(self.width) for i in range(n)]
|
||||
del self.lines[self.height:]
|
||||
|
||||
def deleteLine(self, n=1):
|
||||
del self.lines[self.y:self.y+n]
|
||||
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
|
||||
|
||||
def reportCursorPosition(self):
|
||||
return (self.x, self.y)
|
||||
|
||||
def reset(self):
|
||||
self.home = insults.Vector(0, 0)
|
||||
self.x = self.y = 0
|
||||
self.modes = {}
|
||||
self.privateModes = {}
|
||||
self.setPrivateModes([insults.privateModes.AUTO_WRAP,
|
||||
insults.privateModes.CURSOR_MODE])
|
||||
self.numericKeypad = 'app'
|
||||
self.activeCharset = insults.G0
|
||||
self.graphicRendition = {
|
||||
'bold': False,
|
||||
'underline': False,
|
||||
'blink': False,
|
||||
'reverseVideo': False,
|
||||
'foreground': WHITE,
|
||||
'background': BLACK}
|
||||
self.charsets = {
|
||||
insults.G0: insults.CS_US,
|
||||
insults.G1: insults.CS_US,
|
||||
insults.G2: insults.CS_ALTERNATE,
|
||||
insults.G3: insults.CS_ALTERNATE_SPECIAL}
|
||||
self.eraseDisplay()
|
||||
|
||||
def unhandledControlSequence(self, buf):
|
||||
print 'Could not handle', repr(buf)
|
||||
|
||||
def __str__(self):
|
||||
lines = []
|
||||
for L in self.lines:
|
||||
buf = []
|
||||
length = 0
|
||||
for (ch, attr) in L:
|
||||
if ch is not self.void:
|
||||
buf.append(ch)
|
||||
length = len(buf)
|
||||
else:
|
||||
buf.append(self.fill)
|
||||
lines.append(''.join(buf[:length]))
|
||||
return '\n'.join(lines)
|
||||
|
||||
class ExpectationTimeout(Exception):
|
||||
pass
|
||||
|
||||
class ExpectableBuffer(TerminalBuffer):
|
||||
_mark = 0
|
||||
|
||||
def connectionMade(self):
|
||||
TerminalBuffer.connectionMade(self)
|
||||
self._expecting = []
|
||||
|
||||
def write(self, bytes):
|
||||
TerminalBuffer.write(self, bytes)
|
||||
self._checkExpected()
|
||||
|
||||
def cursorHome(self):
|
||||
TerminalBuffer.cursorHome(self)
|
||||
self._mark = 0
|
||||
|
||||
def _timeoutExpected(self, d):
|
||||
d.errback(ExpectationTimeout())
|
||||
self._checkExpected()
|
||||
|
||||
def _checkExpected(self):
|
||||
s = str(self)[self._mark:]
|
||||
while self._expecting:
|
||||
expr, timer, deferred = self._expecting[0]
|
||||
if timer and not timer.active():
|
||||
del self._expecting[0]
|
||||
continue
|
||||
for match in expr.finditer(s):
|
||||
if timer:
|
||||
timer.cancel()
|
||||
del self._expecting[0]
|
||||
self._mark += match.end()
|
||||
s = s[match.end():]
|
||||
deferred.callback(match)
|
||||
break
|
||||
else:
|
||||
return
|
||||
|
||||
def expect(self, expression, timeout=None, scheduler=reactor):
|
||||
d = defer.Deferred()
|
||||
timer = None
|
||||
if timeout:
|
||||
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
|
||||
self._expecting.append((re.compile(expression), timer, d))
|
||||
self._checkExpected()
|
||||
return d
|
||||
|
||||
__all__ = [
|
||||
'CharacterAttribute', 'TerminalBuffer', 'ExpectableBuffer']
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,175 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_text -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Character attribute manipulation API.
|
||||
|
||||
This module provides a domain-specific language (using Python syntax)
|
||||
for the creation of text with additional display attributes associated
|
||||
with it. It is intended as an alternative to manually building up
|
||||
strings containing ECMA 48 character attribute control codes. It
|
||||
currently supports foreground and background colors (black, red,
|
||||
green, yellow, blue, magenta, cyan, and white), intensity selection,
|
||||
underlining, blinking and reverse video. Character set selection
|
||||
support is planned.
|
||||
|
||||
Character attributes are specified by using two Python operations:
|
||||
attribute lookup and indexing. For example, the string \"Hello
|
||||
world\" with red foreground and all other attributes set to their
|
||||
defaults, assuming the name twisted.conch.insults.text.attributes has
|
||||
been imported and bound to the name \"A\" (with the statement C{from
|
||||
twisted.conch.insults.text import attributes as A}, for example) one
|
||||
uses this expression::
|
||||
|
||||
A.fg.red[\"Hello world\"]
|
||||
|
||||
Other foreground colors are set by substituting their name for
|
||||
\"red\". To set both a foreground and a background color, this
|
||||
expression is used::
|
||||
|
||||
A.fg.red[A.bg.green[\"Hello world\"]]
|
||||
|
||||
Note that either A.bg.green can be nested within A.fg.red or vice
|
||||
versa. Also note that multiple items can be nested within a single
|
||||
index operation by separating them with commas::
|
||||
|
||||
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
|
||||
|
||||
Other character attributes are set in a similar fashion. To specify a
|
||||
blinking version of the previous expression::
|
||||
|
||||
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
|
||||
|
||||
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
|
||||
|
||||
A third operation is actually supported: unary negation. This turns
|
||||
off an attribute when an enclosing expression would otherwise have
|
||||
caused it to be on. For example::
|
||||
|
||||
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
|
||||
|
||||
A formatting structure can then be serialized into a string containing the
|
||||
necessary VT102 control codes with L{assembleFormattedText}.
|
||||
|
||||
@see: L{twisted.conch.insults.text.attributes}
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from twisted.conch.insults import helper, insults
|
||||
from twisted.python import _textattributes
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
from twisted.python.versions import Version
|
||||
|
||||
|
||||
|
||||
flatten = _textattributes.flatten
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version('Twisted', 13, 1, 0),
|
||||
'Use twisted.conch.insults.text.assembleFormattedText instead.',
|
||||
'twisted.conch.insults.text',
|
||||
'flatten')
|
||||
|
||||
_TEXT_COLORS = {
|
||||
'black': helper.BLACK,
|
||||
'red': helper.RED,
|
||||
'green': helper.GREEN,
|
||||
'yellow': helper.YELLOW,
|
||||
'blue': helper.BLUE,
|
||||
'magenta': helper.MAGENTA,
|
||||
'cyan': helper.CYAN,
|
||||
'white': helper.WHITE}
|
||||
|
||||
|
||||
|
||||
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
|
||||
"""
|
||||
Factory for character attributes, including foreground and background color
|
||||
and non-color attributes such as bold, reverse video and underline.
|
||||
|
||||
Character attributes are applied to actual text by using object
|
||||
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
|
||||
example::
|
||||
|
||||
attributes.bold['Some text']
|
||||
|
||||
These can be nested to mix attributes::
|
||||
|
||||
attributes.bold[attributes.underline['Some text']]
|
||||
|
||||
And multiple values can be passed::
|
||||
|
||||
attributes.normal[attributes.bold['Some'], ' text']
|
||||
|
||||
Non-color attributes can be accessed by attribute name, available
|
||||
attributes are:
|
||||
|
||||
- bold
|
||||
- blink
|
||||
- reverseVideo
|
||||
- underline
|
||||
|
||||
Available colors are:
|
||||
|
||||
0. black
|
||||
1. red
|
||||
2. green
|
||||
3. yellow
|
||||
4. blue
|
||||
5. magenta
|
||||
6. cyan
|
||||
7. white
|
||||
|
||||
@ivar fg: Foreground colors accessed by attribute name, see above
|
||||
for possible names.
|
||||
|
||||
@ivar bg: Background colors accessed by attribute name, see above
|
||||
for possible names.
|
||||
"""
|
||||
fg = _textattributes._ColorAttribute(
|
||||
_textattributes._ForegroundColorAttr, _TEXT_COLORS)
|
||||
bg = _textattributes._ColorAttribute(
|
||||
_textattributes._BackgroundColorAttr, _TEXT_COLORS)
|
||||
|
||||
attrs = {
|
||||
'bold': insults.BOLD,
|
||||
'blink': insults.BLINK,
|
||||
'underline': insults.UNDERLINE,
|
||||
'reverseVideo': insults.REVERSE_VIDEO}
|
||||
|
||||
|
||||
|
||||
def assembleFormattedText(formatted):
|
||||
"""
|
||||
Assemble formatted text from structured information.
|
||||
|
||||
Currently handled formatting includes: bold, blink, reverse, underline and
|
||||
color codes.
|
||||
|
||||
For example::
|
||||
|
||||
from twisted.conch.insults.text import attributes as A
|
||||
assembleFormattedText(
|
||||
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
|
||||
|
||||
Would produce "Time: " in bold formatting, followed by "Now!" with a
|
||||
foreground color of light red and without any additional formatting.
|
||||
|
||||
@param formatted: Structured text and attributes.
|
||||
|
||||
@rtype: C{str}
|
||||
@return: String containing VT102 control sequences that mimic those
|
||||
specified by L{formatted}.
|
||||
|
||||
@see: L{twisted.conch.insults.text.attributes}
|
||||
@since: 13.1
|
||||
"""
|
||||
return _textattributes.flatten(
|
||||
formatted, helper._FormattingState(), 'toVT102')
|
||||
|
||||
|
||||
|
||||
attributes = _CharacterAttributes()
|
||||
|
||||
__all__ = ['attributes', 'flatten']
|
@ -0,0 +1,868 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_window -*-
|
||||
|
||||
"""
|
||||
Simple insults-based widget library
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import array
|
||||
|
||||
from twisted.conch.insults import insults, helper
|
||||
from twisted.python import text as tptext
|
||||
|
||||
class YieldFocus(Exception):
|
||||
"""Input focus manipulation exception
|
||||
"""
|
||||
|
||||
class BoundedTerminalWrapper(object):
|
||||
def __init__(self, terminal, width, height, xoff, yoff):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.xoff = xoff
|
||||
self.yoff = yoff
|
||||
self.terminal = terminal
|
||||
self.cursorForward = terminal.cursorForward
|
||||
self.selectCharacterSet = terminal.selectCharacterSet
|
||||
self.selectGraphicRendition = terminal.selectGraphicRendition
|
||||
self.saveCursor = terminal.saveCursor
|
||||
self.restoreCursor = terminal.restoreCursor
|
||||
|
||||
def cursorPosition(self, x, y):
|
||||
return self.terminal.cursorPosition(
|
||||
self.xoff + min(self.width, x),
|
||||
self.yoff + min(self.height, y)
|
||||
)
|
||||
|
||||
def cursorHome(self):
|
||||
return self.terminal.cursorPosition(
|
||||
self.xoff, self.yoff)
|
||||
|
||||
def write(self, bytes):
|
||||
return self.terminal.write(bytes)
|
||||
|
||||
class Widget(object):
|
||||
focused = False
|
||||
parent = None
|
||||
dirty = False
|
||||
width = height = None
|
||||
|
||||
def repaint(self):
|
||||
if not self.dirty:
|
||||
self.dirty = True
|
||||
if self.parent is not None and not self.parent.dirty:
|
||||
self.parent.repaint()
|
||||
|
||||
def filthy(self):
|
||||
self.dirty = True
|
||||
|
||||
def redraw(self, width, height, terminal):
|
||||
self.filthy()
|
||||
self.draw(width, height, terminal)
|
||||
|
||||
def draw(self, width, height, terminal):
|
||||
if width != self.width or height != self.height or self.dirty:
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.dirty = False
|
||||
self.render(width, height, terminal)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
pass
|
||||
|
||||
def sizeHint(self):
|
||||
return None
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
if keyID == '\t':
|
||||
self.tabReceived(modifier)
|
||||
elif keyID == '\x7f':
|
||||
self.backspaceReceived()
|
||||
elif keyID in insults.FUNCTION_KEYS:
|
||||
self.functionKeyReceived(keyID, modifier)
|
||||
else:
|
||||
self.characterReceived(keyID, modifier)
|
||||
|
||||
def tabReceived(self, modifier):
|
||||
# XXX TODO - Handle shift+tab
|
||||
raise YieldFocus()
|
||||
|
||||
def focusReceived(self):
|
||||
"""Called when focus is being given to this widget.
|
||||
|
||||
May raise YieldFocus is this widget does not want focus.
|
||||
"""
|
||||
self.focused = True
|
||||
self.repaint()
|
||||
|
||||
def focusLost(self):
|
||||
self.focused = False
|
||||
self.repaint()
|
||||
|
||||
def backspaceReceived(self):
|
||||
pass
|
||||
|
||||
def functionKeyReceived(self, keyID, modifier):
|
||||
func = getattr(self, 'func_' + keyID.name, None)
|
||||
if func is not None:
|
||||
func(modifier)
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
pass
|
||||
|
||||
class ContainerWidget(Widget):
|
||||
"""
|
||||
@ivar focusedChild: The contained widget which currently has
|
||||
focus, or None.
|
||||
"""
|
||||
focusedChild = None
|
||||
focused = False
|
||||
|
||||
def __init__(self):
|
||||
Widget.__init__(self)
|
||||
self.children = []
|
||||
|
||||
def addChild(self, child):
|
||||
assert child.parent is None
|
||||
child.parent = self
|
||||
self.children.append(child)
|
||||
if self.focusedChild is None and self.focused:
|
||||
try:
|
||||
child.focusReceived()
|
||||
except YieldFocus:
|
||||
pass
|
||||
else:
|
||||
self.focusedChild = child
|
||||
self.repaint()
|
||||
|
||||
def remChild(self, child):
|
||||
assert child.parent is self
|
||||
child.parent = None
|
||||
self.children.remove(child)
|
||||
self.repaint()
|
||||
|
||||
def filthy(self):
|
||||
for ch in self.children:
|
||||
ch.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
for ch in self.children:
|
||||
ch.draw(width, height, terminal)
|
||||
|
||||
def changeFocus(self):
|
||||
self.repaint()
|
||||
|
||||
if self.focusedChild is not None:
|
||||
self.focusedChild.focusLost()
|
||||
focusedChild = self.focusedChild
|
||||
self.focusedChild = None
|
||||
try:
|
||||
curFocus = self.children.index(focusedChild) + 1
|
||||
except ValueError:
|
||||
raise YieldFocus()
|
||||
else:
|
||||
curFocus = 0
|
||||
while curFocus < len(self.children):
|
||||
try:
|
||||
self.children[curFocus].focusReceived()
|
||||
except YieldFocus:
|
||||
curFocus += 1
|
||||
else:
|
||||
self.focusedChild = self.children[curFocus]
|
||||
return
|
||||
# None of our children wanted focus
|
||||
raise YieldFocus()
|
||||
|
||||
|
||||
def focusReceived(self):
|
||||
self.changeFocus()
|
||||
self.focused = True
|
||||
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
if self.focusedChild is not None:
|
||||
try:
|
||||
self.focusedChild.keystrokeReceived(keyID, modifier)
|
||||
except YieldFocus:
|
||||
self.changeFocus()
|
||||
self.repaint()
|
||||
else:
|
||||
Widget.keystrokeReceived(self, keyID, modifier)
|
||||
|
||||
|
||||
class TopWindow(ContainerWidget):
|
||||
"""
|
||||
A top-level container object which provides focus wrap-around and paint
|
||||
scheduling.
|
||||
|
||||
@ivar painter: A no-argument callable which will be invoked when this
|
||||
widget needs to be redrawn.
|
||||
|
||||
@ivar scheduler: A one-argument callable which will be invoked with a
|
||||
no-argument callable and should arrange for it to invoked at some point in
|
||||
the near future. The no-argument callable will cause this widget and all
|
||||
its children to be redrawn. It is typically beneficial for the no-argument
|
||||
callable to be invoked at the end of handling for whatever event is
|
||||
currently active; for example, it might make sense to call it at the end of
|
||||
L{twisted.conch.insults.insults.ITerminalProtocol.keystrokeReceived}.
|
||||
Note, however, that since calls to this may also be made in response to no
|
||||
apparent event, arrangements should be made for the function to be called
|
||||
even if an event handler such as C{keystrokeReceived} is not on the call
|
||||
stack (eg, using C{reactor.callLater} with a short timeout).
|
||||
"""
|
||||
focused = True
|
||||
|
||||
def __init__(self, painter, scheduler):
|
||||
ContainerWidget.__init__(self)
|
||||
self.painter = painter
|
||||
self.scheduler = scheduler
|
||||
|
||||
_paintCall = None
|
||||
def repaint(self):
|
||||
if self._paintCall is None:
|
||||
self._paintCall = object()
|
||||
self.scheduler(self._paint)
|
||||
ContainerWidget.repaint(self)
|
||||
|
||||
def _paint(self):
|
||||
self._paintCall = None
|
||||
self.painter()
|
||||
|
||||
def changeFocus(self):
|
||||
try:
|
||||
ContainerWidget.changeFocus(self)
|
||||
except YieldFocus:
|
||||
try:
|
||||
ContainerWidget.changeFocus(self)
|
||||
except YieldFocus:
|
||||
pass
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
try:
|
||||
ContainerWidget.keystrokeReceived(self, keyID, modifier)
|
||||
except YieldFocus:
|
||||
self.changeFocus()
|
||||
|
||||
|
||||
class AbsoluteBox(ContainerWidget):
|
||||
def moveChild(self, child, x, y):
|
||||
for n in range(len(self.children)):
|
||||
if self.children[n][0] is child:
|
||||
self.children[n] = (child, x, y)
|
||||
break
|
||||
else:
|
||||
raise ValueError("No such child", child)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
for (ch, x, y) in self.children:
|
||||
wrap = BoundedTerminalWrapper(terminal, width - x, height - y, x, y)
|
||||
ch.draw(width, height, wrap)
|
||||
|
||||
|
||||
class _Box(ContainerWidget):
|
||||
TOP, CENTER, BOTTOM = range(3)
|
||||
|
||||
def __init__(self, gravity=CENTER):
|
||||
ContainerWidget.__init__(self)
|
||||
self.gravity = gravity
|
||||
|
||||
def sizeHint(self):
|
||||
height = 0
|
||||
width = 0
|
||||
for ch in self.children:
|
||||
hint = ch.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
|
||||
if self.variableDimension == 0:
|
||||
if hint[0] is None:
|
||||
width = None
|
||||
elif width is not None:
|
||||
width += hint[0]
|
||||
if hint[1] is None:
|
||||
height = None
|
||||
elif height is not None:
|
||||
height = max(height, hint[1])
|
||||
else:
|
||||
if hint[0] is None:
|
||||
width = None
|
||||
elif width is not None:
|
||||
width = max(width, hint[0])
|
||||
if hint[1] is None:
|
||||
height = None
|
||||
elif height is not None:
|
||||
height += hint[1]
|
||||
|
||||
return width, height
|
||||
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if not self.children:
|
||||
return
|
||||
|
||||
greedy = 0
|
||||
wants = []
|
||||
for ch in self.children:
|
||||
hint = ch.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
if hint[self.variableDimension] is None:
|
||||
greedy += 1
|
||||
wants.append(hint[self.variableDimension])
|
||||
|
||||
length = (width, height)[self.variableDimension]
|
||||
totalWant = sum([w for w in wants if w is not None])
|
||||
if greedy:
|
||||
leftForGreedy = int((length - totalWant) / greedy)
|
||||
|
||||
widthOffset = heightOffset = 0
|
||||
|
||||
for want, ch in zip(wants, self.children):
|
||||
if want is None:
|
||||
want = leftForGreedy
|
||||
|
||||
subWidth, subHeight = width, height
|
||||
if self.variableDimension == 0:
|
||||
subWidth = want
|
||||
else:
|
||||
subHeight = want
|
||||
|
||||
wrap = BoundedTerminalWrapper(
|
||||
terminal,
|
||||
subWidth,
|
||||
subHeight,
|
||||
widthOffset,
|
||||
heightOffset,
|
||||
)
|
||||
ch.draw(subWidth, subHeight, wrap)
|
||||
if self.variableDimension == 0:
|
||||
widthOffset += want
|
||||
else:
|
||||
heightOffset += want
|
||||
|
||||
|
||||
class HBox(_Box):
|
||||
variableDimension = 0
|
||||
|
||||
class VBox(_Box):
|
||||
variableDimension = 1
|
||||
|
||||
|
||||
class Packer(ContainerWidget):
|
||||
def render(self, width, height, terminal):
|
||||
if not self.children:
|
||||
return
|
||||
|
||||
root = int(len(self.children) ** 0.5 + 0.5)
|
||||
boxes = [VBox() for n in range(root)]
|
||||
for n, ch in enumerate(self.children):
|
||||
boxes[n % len(boxes)].addChild(ch)
|
||||
h = HBox()
|
||||
map(h.addChild, boxes)
|
||||
h.render(width, height, terminal)
|
||||
|
||||
|
||||
class Canvas(Widget):
|
||||
focused = False
|
||||
|
||||
contents = None
|
||||
|
||||
def __init__(self):
|
||||
Widget.__init__(self)
|
||||
self.resize(1, 1)
|
||||
|
||||
def resize(self, width, height):
|
||||
contents = array.array('c', ' ' * width * height)
|
||||
if self.contents is not None:
|
||||
for x in range(min(width, self._width)):
|
||||
for y in range(min(height, self._height)):
|
||||
contents[width * y + x] = self[x, y]
|
||||
self.contents = contents
|
||||
self._width = width
|
||||
self._height = height
|
||||
if self.x >= width:
|
||||
self.x = width - 1
|
||||
if self.y >= height:
|
||||
self.y = height - 1
|
||||
|
||||
def __getitem__(self, (x, y)):
|
||||
return self.contents[(self._width * y) + x]
|
||||
|
||||
def __setitem__(self, (x, y), value):
|
||||
self.contents[(self._width * y) + x] = value
|
||||
|
||||
def clear(self):
|
||||
self.contents = array.array('c', ' ' * len(self.contents))
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if not width or not height:
|
||||
return
|
||||
|
||||
if width != self._width or height != self._height:
|
||||
self.resize(width, height)
|
||||
for i in range(height):
|
||||
terminal.cursorPosition(0, i)
|
||||
terminal.write(''.join(self.contents[self._width * i:self._width * i + self._width])[:width])
|
||||
|
||||
|
||||
def horizontalLine(terminal, y, left, right):
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
terminal.cursorPosition(left, y)
|
||||
terminal.write(chr(0161) * (right - left))
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
def verticalLine(terminal, x, top, bottom):
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
for n in xrange(top, bottom):
|
||||
terminal.cursorPosition(x, n)
|
||||
terminal.write(chr(0170))
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
|
||||
def rectangle(terminal, (top, left), (width, height)):
|
||||
terminal.selectCharacterSet(insults.CS_DRAWING, insults.G0)
|
||||
|
||||
terminal.cursorPosition(top, left)
|
||||
terminal.write(chr(0154))
|
||||
terminal.write(chr(0161) * (width - 2))
|
||||
terminal.write(chr(0153))
|
||||
for n in range(height - 2):
|
||||
terminal.cursorPosition(left, top + n + 1)
|
||||
terminal.write(chr(0170))
|
||||
terminal.cursorForward(width - 2)
|
||||
terminal.write(chr(0170))
|
||||
terminal.cursorPosition(0, top + height - 1)
|
||||
terminal.write(chr(0155))
|
||||
terminal.write(chr(0161) * (width - 2))
|
||||
terminal.write(chr(0152))
|
||||
|
||||
terminal.selectCharacterSet(insults.CS_US, insults.G0)
|
||||
|
||||
class Border(Widget):
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self.containee = containee
|
||||
self.containee.parent = self
|
||||
|
||||
def focusReceived(self):
|
||||
return self.containee.focusReceived()
|
||||
|
||||
def focusLost(self):
|
||||
return self.containee.focusLost()
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
return self.containee.keystrokeReceived(keyID, modifier)
|
||||
|
||||
def sizeHint(self):
|
||||
hint = self.containee.sizeHint()
|
||||
if hint is None:
|
||||
hint = (None, None)
|
||||
if hint[0] is None:
|
||||
x = None
|
||||
else:
|
||||
x = hint[0] + 2
|
||||
if hint[1] is None:
|
||||
y = None
|
||||
else:
|
||||
y = hint[1] + 2
|
||||
return x, y
|
||||
|
||||
def filthy(self):
|
||||
self.containee.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
if self.containee.focused:
|
||||
terminal.write('\x1b[31m')
|
||||
rectangle(terminal, (0, 0), (width, height))
|
||||
terminal.write('\x1b[0m')
|
||||
wrap = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
|
||||
self.containee.draw(width - 2, height - 2, wrap)
|
||||
|
||||
|
||||
class Button(Widget):
|
||||
def __init__(self, label, onPress):
|
||||
Widget.__init__(self)
|
||||
self.label = label
|
||||
self.onPress = onPress
|
||||
|
||||
def sizeHint(self):
|
||||
return len(self.label), 1
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == '\r':
|
||||
self.onPress()
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
if self.focused:
|
||||
terminal.write('\x1b[1m' + self.label + '\x1b[0m')
|
||||
else:
|
||||
terminal.write(self.label)
|
||||
|
||||
class TextInput(Widget):
|
||||
def __init__(self, maxwidth, onSubmit):
|
||||
Widget.__init__(self)
|
||||
self.onSubmit = onSubmit
|
||||
self.maxwidth = maxwidth
|
||||
self.buffer = ''
|
||||
self.cursor = 0
|
||||
|
||||
def setText(self, text):
|
||||
self.buffer = text[:self.maxwidth]
|
||||
self.cursor = len(self.buffer)
|
||||
self.repaint()
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
if self.cursor > 0:
|
||||
self.cursor -= 1
|
||||
self.repaint()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
if self.cursor < len(self.buffer):
|
||||
self.cursor += 1
|
||||
self.repaint()
|
||||
|
||||
def backspaceReceived(self):
|
||||
if self.cursor > 0:
|
||||
self.buffer = self.buffer[:self.cursor - 1] + self.buffer[self.cursor:]
|
||||
self.cursor -= 1
|
||||
self.repaint()
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == '\r':
|
||||
self.onSubmit(self.buffer)
|
||||
else:
|
||||
if len(self.buffer) < self.maxwidth:
|
||||
self.buffer = self.buffer[:self.cursor] + keyID + self.buffer[self.cursor:]
|
||||
self.cursor += 1
|
||||
self.repaint()
|
||||
|
||||
def sizeHint(self):
|
||||
return self.maxwidth + 1, 1
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
currentText = self._renderText()
|
||||
terminal.cursorPosition(0, 0)
|
||||
if self.focused:
|
||||
terminal.write(currentText[:self.cursor])
|
||||
cursor(terminal, currentText[self.cursor:self.cursor+1] or ' ')
|
||||
terminal.write(currentText[self.cursor+1:])
|
||||
terminal.write(' ' * (self.maxwidth - len(currentText) + 1))
|
||||
else:
|
||||
more = self.maxwidth - len(currentText)
|
||||
terminal.write(currentText + '_' * more)
|
||||
|
||||
def _renderText(self):
|
||||
return self.buffer
|
||||
|
||||
class PasswordInput(TextInput):
|
||||
def _renderText(self):
|
||||
return '*' * len(self.buffer)
|
||||
|
||||
class TextOutput(Widget):
|
||||
text = ''
|
||||
|
||||
def __init__(self, size=None):
|
||||
Widget.__init__(self)
|
||||
self.size = size
|
||||
|
||||
def sizeHint(self):
|
||||
return self.size
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
text = self.text[:width]
|
||||
terminal.write(text + ' ' * (width - len(text)))
|
||||
|
||||
def setText(self, text):
|
||||
self.text = text
|
||||
self.repaint()
|
||||
|
||||
def focusReceived(self):
|
||||
raise YieldFocus()
|
||||
|
||||
class TextOutputArea(TextOutput):
|
||||
WRAP, TRUNCATE = range(2)
|
||||
|
||||
def __init__(self, size=None, longLines=WRAP):
|
||||
TextOutput.__init__(self, size)
|
||||
self.longLines = longLines
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
n = 0
|
||||
inputLines = self.text.splitlines()
|
||||
outputLines = []
|
||||
while inputLines:
|
||||
if self.longLines == self.WRAP:
|
||||
wrappedLines = tptext.greedyWrap(inputLines.pop(0), width)
|
||||
outputLines.extend(wrappedLines or [''])
|
||||
else:
|
||||
outputLines.append(inputLines.pop(0)[:width])
|
||||
if len(outputLines) >= height:
|
||||
break
|
||||
for n, L in enumerate(outputLines[:height]):
|
||||
terminal.cursorPosition(0, n)
|
||||
terminal.write(L)
|
||||
|
||||
class Viewport(Widget):
|
||||
_xOffset = 0
|
||||
_yOffset = 0
|
||||
|
||||
def xOffset():
|
||||
def get(self):
|
||||
return self._xOffset
|
||||
def set(self, value):
|
||||
if self._xOffset != value:
|
||||
self._xOffset = value
|
||||
self.repaint()
|
||||
return get, set
|
||||
xOffset = property(*xOffset())
|
||||
|
||||
def yOffset():
|
||||
def get(self):
|
||||
return self._yOffset
|
||||
def set(self, value):
|
||||
if self._yOffset != value:
|
||||
self._yOffset = value
|
||||
self.repaint()
|
||||
return get, set
|
||||
yOffset = property(*yOffset())
|
||||
|
||||
_width = 160
|
||||
_height = 24
|
||||
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self.containee = containee
|
||||
self.containee.parent = self
|
||||
|
||||
self._buf = helper.TerminalBuffer()
|
||||
self._buf.width = self._width
|
||||
self._buf.height = self._height
|
||||
self._buf.connectionMade()
|
||||
|
||||
def filthy(self):
|
||||
self.containee.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
self.containee.draw(self._width, self._height, self._buf)
|
||||
|
||||
# XXX /Lame/
|
||||
for y, line in enumerate(self._buf.lines[self._yOffset:self._yOffset + height]):
|
||||
terminal.cursorPosition(0, y)
|
||||
n = 0
|
||||
for n, (ch, attr) in enumerate(line[self._xOffset:self._xOffset + width]):
|
||||
if ch is self._buf.void:
|
||||
ch = ' '
|
||||
terminal.write(ch)
|
||||
if n < width:
|
||||
terminal.write(' ' * (width - n - 1))
|
||||
|
||||
|
||||
class _Scrollbar(Widget):
|
||||
def __init__(self, onScroll):
|
||||
Widget.__init__(self)
|
||||
self.onScroll = onScroll
|
||||
self.percent = 0.0
|
||||
|
||||
def smaller(self):
|
||||
self.percent = min(1.0, max(0.0, self.onScroll(-1)))
|
||||
self.repaint()
|
||||
|
||||
def bigger(self):
|
||||
self.percent = min(1.0, max(0.0, self.onScroll(+1)))
|
||||
self.repaint()
|
||||
|
||||
|
||||
class HorizontalScrollbar(_Scrollbar):
|
||||
def sizeHint(self):
|
||||
return (None, 1)
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
self.smaller()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
self.bigger()
|
||||
|
||||
_left = u'\N{BLACK LEFT-POINTING TRIANGLE}'
|
||||
_right = u'\N{BLACK RIGHT-POINTING TRIANGLE}'
|
||||
_bar = u'\N{LIGHT SHADE}'
|
||||
_slider = u'\N{DARK SHADE}'
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
n = width - 3
|
||||
before = int(n * self.percent)
|
||||
after = n - before
|
||||
me = self._left + (self._bar * before) + self._slider + (self._bar * after) + self._right
|
||||
terminal.write(me.encode('utf-8'))
|
||||
|
||||
|
||||
class VerticalScrollbar(_Scrollbar):
|
||||
def sizeHint(self):
|
||||
return (1, None)
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
self.smaller()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
self.bigger()
|
||||
|
||||
_up = u'\N{BLACK UP-POINTING TRIANGLE}'
|
||||
_down = u'\N{BLACK DOWN-POINTING TRIANGLE}'
|
||||
_bar = u'\N{LIGHT SHADE}'
|
||||
_slider = u'\N{DARK SHADE}'
|
||||
def render(self, width, height, terminal):
|
||||
terminal.cursorPosition(0, 0)
|
||||
knob = int(self.percent * (height - 2))
|
||||
terminal.write(self._up.encode('utf-8'))
|
||||
for i in xrange(1, height - 1):
|
||||
terminal.cursorPosition(0, i)
|
||||
if i != (knob + 1):
|
||||
terminal.write(self._bar.encode('utf-8'))
|
||||
else:
|
||||
terminal.write(self._slider.encode('utf-8'))
|
||||
terminal.cursorPosition(0, height - 1)
|
||||
terminal.write(self._down.encode('utf-8'))
|
||||
|
||||
|
||||
class ScrolledArea(Widget):
|
||||
"""
|
||||
A L{ScrolledArea} contains another widget wrapped in a viewport and
|
||||
vertical and horizontal scrollbars for moving the viewport around.
|
||||
"""
|
||||
def __init__(self, containee):
|
||||
Widget.__init__(self)
|
||||
self._viewport = Viewport(containee)
|
||||
self._horiz = HorizontalScrollbar(self._horizScroll)
|
||||
self._vert = VerticalScrollbar(self._vertScroll)
|
||||
|
||||
for w in self._viewport, self._horiz, self._vert:
|
||||
w.parent = self
|
||||
|
||||
def _horizScroll(self, n):
|
||||
self._viewport.xOffset += n
|
||||
self._viewport.xOffset = max(0, self._viewport.xOffset)
|
||||
return self._viewport.xOffset / 25.0
|
||||
|
||||
def _vertScroll(self, n):
|
||||
self._viewport.yOffset += n
|
||||
self._viewport.yOffset = max(0, self._viewport.yOffset)
|
||||
return self._viewport.yOffset / 25.0
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
self._vert.smaller()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
self._vert.bigger()
|
||||
|
||||
def func_LEFT_ARROW(self, modifier):
|
||||
self._horiz.smaller()
|
||||
|
||||
def func_RIGHT_ARROW(self, modifier):
|
||||
self._horiz.bigger()
|
||||
|
||||
def filthy(self):
|
||||
self._viewport.filthy()
|
||||
self._horiz.filthy()
|
||||
self._vert.filthy()
|
||||
Widget.filthy(self)
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
wrapper = BoundedTerminalWrapper(terminal, width - 2, height - 2, 1, 1)
|
||||
self._viewport.draw(width - 2, height - 2, wrapper)
|
||||
if self.focused:
|
||||
terminal.write('\x1b[31m')
|
||||
horizontalLine(terminal, 0, 1, width - 1)
|
||||
verticalLine(terminal, 0, 1, height - 1)
|
||||
self._vert.draw(1, height - 1, BoundedTerminalWrapper(terminal, 1, height - 1, width - 1, 0))
|
||||
self._horiz.draw(width, 1, BoundedTerminalWrapper(terminal, width, 1, 0, height - 1))
|
||||
terminal.write('\x1b[0m')
|
||||
|
||||
def cursor(terminal, ch):
|
||||
terminal.saveCursor()
|
||||
terminal.selectGraphicRendition(str(insults.REVERSE_VIDEO))
|
||||
terminal.write(ch)
|
||||
terminal.restoreCursor()
|
||||
terminal.cursorForward()
|
||||
|
||||
class Selection(Widget):
|
||||
# Index into the sequence
|
||||
focusedIndex = 0
|
||||
|
||||
# Offset into the displayed subset of the sequence
|
||||
renderOffset = 0
|
||||
|
||||
def __init__(self, sequence, onSelect, minVisible=None):
|
||||
Widget.__init__(self)
|
||||
self.sequence = sequence
|
||||
self.onSelect = onSelect
|
||||
self.minVisible = minVisible
|
||||
if minVisible is not None:
|
||||
self._width = max(map(len, self.sequence))
|
||||
|
||||
def sizeHint(self):
|
||||
if self.minVisible is not None:
|
||||
return self._width, self.minVisible
|
||||
|
||||
def func_UP_ARROW(self, modifier):
|
||||
if self.focusedIndex > 0:
|
||||
self.focusedIndex -= 1
|
||||
if self.renderOffset > 0:
|
||||
self.renderOffset -= 1
|
||||
self.repaint()
|
||||
|
||||
def func_PGUP(self, modifier):
|
||||
if self.renderOffset != 0:
|
||||
self.focusedIndex -= self.renderOffset
|
||||
self.renderOffset = 0
|
||||
else:
|
||||
self.focusedIndex = max(0, self.focusedIndex - self.height)
|
||||
self.repaint()
|
||||
|
||||
def func_DOWN_ARROW(self, modifier):
|
||||
if self.focusedIndex < len(self.sequence) - 1:
|
||||
self.focusedIndex += 1
|
||||
if self.renderOffset < self.height - 1:
|
||||
self.renderOffset += 1
|
||||
self.repaint()
|
||||
|
||||
|
||||
def func_PGDN(self, modifier):
|
||||
if self.renderOffset != self.height - 1:
|
||||
change = self.height - self.renderOffset - 1
|
||||
if change + self.focusedIndex >= len(self.sequence):
|
||||
change = len(self.sequence) - self.focusedIndex - 1
|
||||
self.focusedIndex += change
|
||||
self.renderOffset = self.height - 1
|
||||
else:
|
||||
self.focusedIndex = min(len(self.sequence) - 1, self.focusedIndex + self.height)
|
||||
self.repaint()
|
||||
|
||||
def characterReceived(self, keyID, modifier):
|
||||
if keyID == '\r':
|
||||
self.onSelect(self.sequence[self.focusedIndex])
|
||||
|
||||
def render(self, width, height, terminal):
|
||||
self.height = height
|
||||
start = self.focusedIndex - self.renderOffset
|
||||
if start > len(self.sequence) - height:
|
||||
start = max(0, len(self.sequence) - height)
|
||||
|
||||
elements = self.sequence[start:start+height]
|
||||
|
||||
for n, ele in enumerate(elements):
|
||||
terminal.cursorPosition(0, n)
|
||||
if n == self.renderOffset:
|
||||
terminal.saveCursor()
|
||||
if self.focused:
|
||||
modes = str(insults.REVERSE_VIDEO), str(insults.BOLD)
|
||||
else:
|
||||
modes = str(insults.REVERSE_VIDEO),
|
||||
terminal.selectGraphicRendition(*modes)
|
||||
text = ele[:width]
|
||||
terminal.write(text + (' ' * (width - len(text))))
|
||||
if n == self.renderOffset:
|
||||
terminal.restoreCursor()
|
@ -0,0 +1,408 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains interfaces defined for the L{twisted.conch} package.
|
||||
"""
|
||||
|
||||
from zope.interface import Interface, Attribute
|
||||
|
||||
class IConchUser(Interface):
|
||||
"""
|
||||
A user who has been authenticated to Cred through Conch. This is
|
||||
the interface between the SSH connection and the user.
|
||||
"""
|
||||
|
||||
conn = Attribute('The SSHConnection object for this user.')
|
||||
|
||||
def lookupChannel(channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side requested a channel of some sort.
|
||||
channelType is the type of channel being requested,
|
||||
windowSize is the initial size of the remote window,
|
||||
maxPacket is the largest packet we should send,
|
||||
data is any other packet data (often nothing).
|
||||
|
||||
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
|
||||
an appropriate channel can not be found, an exception will be
|
||||
raised. If a L{ConchError<error.ConchError>} is raised, the .value
|
||||
will be the message, and the .data will be the error code.
|
||||
|
||||
@type channelType: C{str}
|
||||
@type windowSize: C{int}
|
||||
@type maxPacket: C{int}
|
||||
@type data: C{str}
|
||||
@rtype: subclass of L{SSHChannel}/C{tuple}
|
||||
"""
|
||||
|
||||
def lookupSubsystem(subsystem, data):
|
||||
"""
|
||||
The other side requested a subsystem.
|
||||
subsystem is the name of the subsystem being requested.
|
||||
data is any other packet data (often nothing).
|
||||
|
||||
We return a L{Protocol}.
|
||||
"""
|
||||
|
||||
def gotGlobalRequest(requestType, data):
|
||||
"""
|
||||
A global request was sent from the other side.
|
||||
|
||||
By default, this dispatches to a method 'channel_channelType' with any
|
||||
non-alphanumerics in the channelType replace with _'s. If it cannot
|
||||
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
|
||||
The method is called with arguments of windowSize, maxPacket, data.
|
||||
"""
|
||||
|
||||
class ISession(Interface):
|
||||
|
||||
def getPty(term, windowSize, modes):
|
||||
"""
|
||||
Get a pseudo-terminal for use by a shell or command.
|
||||
|
||||
If a pseudo-terminal is not available, or the request otherwise
|
||||
fails, raise an exception.
|
||||
"""
|
||||
|
||||
def openShell(proto):
|
||||
"""
|
||||
Open a shell and connect it to proto.
|
||||
|
||||
@param proto: a L{ProcessProtocol} instance.
|
||||
"""
|
||||
|
||||
def execCommand(proto, command):
|
||||
"""
|
||||
Execute a command.
|
||||
|
||||
@param proto: a L{ProcessProtocol} instance.
|
||||
"""
|
||||
|
||||
def windowChanged(newWindowSize):
|
||||
"""
|
||||
Called when the size of the remote screen has changed.
|
||||
"""
|
||||
|
||||
def eofReceived():
|
||||
"""
|
||||
Called when the other side has indicated no more data will be sent.
|
||||
"""
|
||||
|
||||
def closed():
|
||||
"""
|
||||
Called when the session is closed.
|
||||
"""
|
||||
|
||||
|
||||
class ISFTPServer(Interface):
|
||||
"""
|
||||
SFTP subsystem for server-side communication.
|
||||
|
||||
Each method should check to verify that the user has permission for
|
||||
their actions.
|
||||
"""
|
||||
|
||||
avatar = Attribute(
|
||||
"""
|
||||
The avatar returned by the Realm that we are authenticated with,
|
||||
and represents the logged-in user.
|
||||
""")
|
||||
|
||||
def gotVersion(otherVersion, extData):
|
||||
"""
|
||||
Called when the client sends their version info.
|
||||
|
||||
otherVersion is an integer representing the version of the SFTP
|
||||
protocol they are claiming.
|
||||
extData is a dictionary of extended_name : extended_data items.
|
||||
These items are sent by the client to indicate additional features.
|
||||
|
||||
This method should return a dictionary of extended_name : extended_data
|
||||
items. These items are the additional features (if any) supported
|
||||
by the server.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def openFile(filename, flags, attrs):
|
||||
"""
|
||||
Called when the clients asks to open a file.
|
||||
|
||||
@param filename: a string representing the file to open.
|
||||
|
||||
@param flags: an integer of the flags to open the file with, ORed together.
|
||||
The flags and their values are listed at the bottom of this file.
|
||||
|
||||
@param attrs: a list of attributes to open the file with. It is a
|
||||
dictionary, consisting of 0 or more keys. The possible keys are::
|
||||
|
||||
size: the size of the file in bytes
|
||||
uid: the user ID of the file as an integer
|
||||
gid: the group ID of the file as an integer
|
||||
permissions: the permissions of the file with as an integer.
|
||||
the bit representation of this field is defined by POSIX.
|
||||
atime: the access time of the file as seconds since the epoch.
|
||||
mtime: the modification time of the file as seconds since the epoch.
|
||||
ext_*: extended attributes. The server is not required to
|
||||
understand this, but it may.
|
||||
|
||||
NOTE: there is no way to indicate text or binary files. it is up
|
||||
to the SFTP client to deal with this.
|
||||
|
||||
This method returns an object that meets the ISFTPFile interface.
|
||||
Alternatively, it can return a L{Deferred} that will be called back
|
||||
with the object.
|
||||
"""
|
||||
|
||||
def removeFile(filename):
|
||||
"""
|
||||
Remove the given file.
|
||||
|
||||
This method returns when the remove succeeds, or a Deferred that is
|
||||
called back when it succeeds.
|
||||
|
||||
@param filename: the name of the file as a string.
|
||||
"""
|
||||
|
||||
def renameFile(oldpath, newpath):
|
||||
"""
|
||||
Rename the given file.
|
||||
|
||||
This method returns when the rename succeeds, or a L{Deferred} that is
|
||||
called back when it succeeds. If the rename fails, C{renameFile} will
|
||||
raise an implementation-dependent exception.
|
||||
|
||||
@param oldpath: the current location of the file.
|
||||
@param newpath: the new file name.
|
||||
"""
|
||||
|
||||
def makeDirectory(path, attrs):
|
||||
"""
|
||||
Make a directory.
|
||||
|
||||
This method returns when the directory is created, or a Deferred that
|
||||
is called back when it is created.
|
||||
|
||||
@param path: the name of the directory to create as a string.
|
||||
@param attrs: a dictionary of attributes to create the directory with.
|
||||
Its meaning is the same as the attrs in the L{openFile} method.
|
||||
"""
|
||||
|
||||
def removeDirectory(path):
|
||||
"""
|
||||
Remove a directory (non-recursively)
|
||||
|
||||
It is an error to remove a directory that has files or directories in
|
||||
it.
|
||||
|
||||
This method returns when the directory is removed, or a Deferred that
|
||||
is called back when it is removed.
|
||||
|
||||
@param path: the directory to remove.
|
||||
"""
|
||||
|
||||
def openDirectory(path):
|
||||
"""
|
||||
Open a directory for scanning.
|
||||
|
||||
This method returns an iterable object that has a close() method,
|
||||
or a Deferred that is called back with same.
|
||||
|
||||
The close() method is called when the client is finished reading
|
||||
from the directory. At this point, the iterable will no longer
|
||||
be used.
|
||||
|
||||
The iterable should return triples of the form (filename,
|
||||
longname, attrs) or Deferreds that return the same. The
|
||||
sequence must support __getitem__, but otherwise may be any
|
||||
'sequence-like' object.
|
||||
|
||||
filename is the name of the file relative to the directory.
|
||||
logname is an expanded format of the filename. The recommended format
|
||||
is:
|
||||
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
|
||||
1234567890 123 12345678 12345678 12345678 123456789012
|
||||
|
||||
The first line is sample output, the second is the length of the field.
|
||||
The fields are: permissions, link count, user owner, group owner,
|
||||
size in bytes, modification time.
|
||||
|
||||
attrs is a dictionary in the format of the attrs argument to openFile.
|
||||
|
||||
@param path: the directory to open.
|
||||
"""
|
||||
|
||||
def getAttrs(path, followLinks):
|
||||
"""
|
||||
Return the attributes for the given path.
|
||||
|
||||
This method returns a dictionary in the same format as the attrs
|
||||
argument to openFile or a Deferred that is called back with same.
|
||||
|
||||
@param path: the path to return attributes for as a string.
|
||||
@param followLinks: a boolean. If it is True, follow symbolic links
|
||||
and return attributes for the real path at the base. If it is False,
|
||||
return attributes for the specified path.
|
||||
"""
|
||||
|
||||
def setAttrs(path, attrs):
|
||||
"""
|
||||
Set the attributes for the path.
|
||||
|
||||
This method returns when the attributes are set or a Deferred that is
|
||||
called back when they are.
|
||||
|
||||
@param path: the path to set attributes for as a string.
|
||||
@param attrs: a dictionary in the same format as the attrs argument to
|
||||
L{openFile}.
|
||||
"""
|
||||
|
||||
def readLink(path):
|
||||
"""
|
||||
Find the root of a set of symbolic links.
|
||||
|
||||
This method returns the target of the link, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param path: the path of the symlink to read.
|
||||
"""
|
||||
|
||||
def makeLink(linkPath, targetPath):
|
||||
"""
|
||||
Create a symbolic link.
|
||||
|
||||
This method returns when the link is made, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param linkPath: the pathname of the symlink as a string.
|
||||
@param targetPath: the path of the target of the link as a string.
|
||||
"""
|
||||
|
||||
def realPath(path):
|
||||
"""
|
||||
Convert any path to an absolute path.
|
||||
|
||||
This method returns the absolute path as a string, or a Deferred
|
||||
that returns the same.
|
||||
|
||||
@param path: the path to convert as a string.
|
||||
"""
|
||||
|
||||
def extendedRequest(extendedName, extendedData):
|
||||
"""
|
||||
This is the extension mechanism for SFTP. The other side can send us
|
||||
arbitrary requests.
|
||||
|
||||
If we don't implement the request given by extendedName, raise
|
||||
NotImplementedError.
|
||||
|
||||
The return value is a string, or a Deferred that will be called
|
||||
back with a string.
|
||||
|
||||
@param extendedName: the name of the request as a string.
|
||||
@param extendedData: the data the other side sent with the request,
|
||||
as a string.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class IKnownHostEntry(Interface):
|
||||
"""
|
||||
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
|
||||
file.
|
||||
|
||||
@since: 8.2
|
||||
"""
|
||||
|
||||
def matchesKey(key):
|
||||
"""
|
||||
Return True if this entry matches the given Key object, False
|
||||
otherwise.
|
||||
|
||||
@param key: The key object to match against.
|
||||
@type key: L{twisted.conch.ssh.Key}
|
||||
"""
|
||||
|
||||
|
||||
def matchesHost(hostname):
|
||||
"""
|
||||
Return True if this entry matches the given hostname, False otherwise.
|
||||
|
||||
Note that this does no name resolution; if you want to match an IP
|
||||
address, you have to resolve it yourself, and pass it in as a dotted
|
||||
quad string.
|
||||
|
||||
@param key: The hostname to match against.
|
||||
@type key: L{str}
|
||||
"""
|
||||
|
||||
|
||||
def toString():
|
||||
"""
|
||||
@return: a serialized string representation of this entry, suitable for
|
||||
inclusion in a known_hosts file. (Newline not included.)
|
||||
|
||||
@rtype: L{str}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ISFTPFile(Interface):
|
||||
"""
|
||||
This represents an open file on the server. An object adhering to this
|
||||
interface should be returned from L{openFile}().
|
||||
"""
|
||||
|
||||
def close():
|
||||
"""
|
||||
Close the file.
|
||||
|
||||
This method returns nothing if the close succeeds immediately, or a
|
||||
Deferred that is called back when the close succeeds.
|
||||
"""
|
||||
|
||||
def readChunk(offset, length):
|
||||
"""
|
||||
Read from the file.
|
||||
|
||||
If EOF is reached before any data is read, raise EOFError.
|
||||
|
||||
This method returns the data as a string, or a Deferred that is
|
||||
called back with same.
|
||||
|
||||
@param offset: an integer that is the index to start from in the file.
|
||||
@param length: the maximum length of data to return. The actual amount
|
||||
returned may less than this. For normal disk files, however,
|
||||
this should read the requested number (up to the end of the file).
|
||||
"""
|
||||
|
||||
def writeChunk(offset, data):
|
||||
"""
|
||||
Write to the file.
|
||||
|
||||
This method returns when the write completes, or a Deferred that is
|
||||
called when it completes.
|
||||
|
||||
@param offset: an integer that is the index to start from in the file.
|
||||
@param data: a string that is the data to write.
|
||||
"""
|
||||
|
||||
def getAttrs():
|
||||
"""
|
||||
Return the attributes for the file.
|
||||
|
||||
This method returns a dictionary in the same format as the attrs
|
||||
argument to L{openFile} or a L{Deferred} that is called back with same.
|
||||
"""
|
||||
|
||||
def setAttrs(attrs):
|
||||
"""
|
||||
Set the attributes for the file.
|
||||
|
||||
This method returns when the attributes are set or a Deferred that is
|
||||
called back when they are.
|
||||
|
||||
@param attrs: a dictionary in the same format as the attrs argument to
|
||||
L{openFile}.
|
||||
"""
|
||||
|
||||
|
@ -0,0 +1,75 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_cftp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import array
|
||||
import stat
|
||||
|
||||
from time import time, strftime, localtime
|
||||
|
||||
# locale-independent month names to use instead of strftime's
|
||||
_MONTH_NAMES = dict(zip(
|
||||
range(1, 13),
|
||||
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()))
|
||||
|
||||
|
||||
def lsLine(name, s):
|
||||
"""
|
||||
Build an 'ls' line for a file ('file' in its generic sense, it
|
||||
can be of any type).
|
||||
"""
|
||||
mode = s.st_mode
|
||||
perms = array.array('c', '-'*10)
|
||||
ft = stat.S_IFMT(mode)
|
||||
if stat.S_ISDIR(ft): perms[0] = 'd'
|
||||
elif stat.S_ISCHR(ft): perms[0] = 'c'
|
||||
elif stat.S_ISBLK(ft): perms[0] = 'b'
|
||||
elif stat.S_ISREG(ft): perms[0] = '-'
|
||||
elif stat.S_ISFIFO(ft): perms[0] = 'f'
|
||||
elif stat.S_ISLNK(ft): perms[0] = 'l'
|
||||
elif stat.S_ISSOCK(ft): perms[0] = 's'
|
||||
else: perms[0] = '!'
|
||||
# user
|
||||
if mode&stat.S_IRUSR:perms[1] = 'r'
|
||||
if mode&stat.S_IWUSR:perms[2] = 'w'
|
||||
if mode&stat.S_IXUSR:perms[3] = 'x'
|
||||
# group
|
||||
if mode&stat.S_IRGRP:perms[4] = 'r'
|
||||
if mode&stat.S_IWGRP:perms[5] = 'w'
|
||||
if mode&stat.S_IXGRP:perms[6] = 'x'
|
||||
# other
|
||||
if mode&stat.S_IROTH:perms[7] = 'r'
|
||||
if mode&stat.S_IWOTH:perms[8] = 'w'
|
||||
if mode&stat.S_IXOTH:perms[9] = 'x'
|
||||
# suid/sgid
|
||||
if mode&stat.S_ISUID:
|
||||
if perms[3] == 'x': perms[3] = 's'
|
||||
else: perms[3] = 'S'
|
||||
if mode&stat.S_ISGID:
|
||||
if perms[6] == 'x': perms[6] = 's'
|
||||
else: perms[6] = 'S'
|
||||
|
||||
lsresult = [
|
||||
perms.tostring(),
|
||||
str(s.st_nlink).rjust(5),
|
||||
' ',
|
||||
str(s.st_uid).ljust(9),
|
||||
str(s.st_gid).ljust(9),
|
||||
str(s.st_size).rjust(8),
|
||||
' ',
|
||||
]
|
||||
|
||||
# need to specify the month manually, as strftime depends on locale
|
||||
ttup = localtime(s.st_mtime)
|
||||
sixmonths = 60 * 60 * 24 * 7 * 26
|
||||
if s.st_mtime + sixmonths < time(): # last edited more than 6mo ago
|
||||
strtime = strftime("%%s %d %Y ", ttup)
|
||||
else:
|
||||
strtime = strftime("%%s %d %H:%M ", ttup)
|
||||
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
|
||||
|
||||
lsresult.append(name)
|
||||
return ''.join(lsresult)
|
||||
|
||||
|
||||
__all__ = ['lsLine']
|
@ -0,0 +1,340 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Line-input oriented interactive interpreter loop.
|
||||
|
||||
Provides classes for handling Python source input and arbitrary output
|
||||
interactively from a Twisted application. Also included is syntax coloring
|
||||
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
|
||||
and reasonable handling of Deferreds.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import code, sys, StringIO, tokenize
|
||||
|
||||
from twisted.conch import recvline
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python.htmlizer import TokenPrinter
|
||||
|
||||
class FileWrapper:
|
||||
"""Minimal write-file-like object.
|
||||
|
||||
Writes are translated into addOutput calls on an object passed to
|
||||
__init__. Newlines are also converted from network to local style.
|
||||
"""
|
||||
|
||||
softspace = 0
|
||||
state = 'normal'
|
||||
|
||||
def __init__(self, o):
|
||||
self.o = o
|
||||
|
||||
def flush(self):
|
||||
pass
|
||||
|
||||
def write(self, data):
|
||||
self.o.addOutput(data.replace('\r\n', '\n'))
|
||||
|
||||
def writelines(self, lines):
|
||||
self.write(''.join(lines))
|
||||
|
||||
class ManholeInterpreter(code.InteractiveInterpreter):
|
||||
"""Interactive Interpreter with special output and Deferred support.
|
||||
|
||||
Aside from the features provided by L{code.InteractiveInterpreter}, this
|
||||
class captures sys.stdout output and redirects it to the appropriate
|
||||
location (the Manhole protocol instance). It also treats Deferreds
|
||||
which reach the top-level specially: each is formatted to the user with
|
||||
a unique identifier and a new callback and errback added to it, each of
|
||||
which will format the unique identifier and the result with which the
|
||||
Deferred fires and then pass it on to the next participant in the
|
||||
callback chain.
|
||||
"""
|
||||
|
||||
numDeferreds = 0
|
||||
def __init__(self, handler, locals=None, filename="<console>"):
|
||||
code.InteractiveInterpreter.__init__(self, locals)
|
||||
self._pendingDeferreds = {}
|
||||
self.handler = handler
|
||||
self.filename = filename
|
||||
self.resetBuffer()
|
||||
|
||||
def resetBuffer(self):
|
||||
"""Reset the input buffer."""
|
||||
self.buffer = []
|
||||
|
||||
def push(self, line):
|
||||
"""Push a line to the interpreter.
|
||||
|
||||
The line should not have a trailing newline; it may have
|
||||
internal newlines. The line is appended to a buffer and the
|
||||
interpreter's runsource() method is called with the
|
||||
concatenated contents of the buffer as source. If this
|
||||
indicates that the command was executed or invalid, the buffer
|
||||
is reset; otherwise, the command is incomplete, and the buffer
|
||||
is left as it was after the line was appended. The return
|
||||
value is 1 if more input is required, 0 if the line was dealt
|
||||
with in some way (this is the same as runsource()).
|
||||
|
||||
"""
|
||||
self.buffer.append(line)
|
||||
source = "\n".join(self.buffer)
|
||||
more = self.runsource(source, self.filename)
|
||||
if not more:
|
||||
self.resetBuffer()
|
||||
return more
|
||||
|
||||
def runcode(self, *a, **kw):
|
||||
orighook, sys.displayhook = sys.displayhook, self.displayhook
|
||||
try:
|
||||
origout, sys.stdout = sys.stdout, FileWrapper(self.handler)
|
||||
try:
|
||||
code.InteractiveInterpreter.runcode(self, *a, **kw)
|
||||
finally:
|
||||
sys.stdout = origout
|
||||
finally:
|
||||
sys.displayhook = orighook
|
||||
|
||||
def displayhook(self, obj):
|
||||
self.locals['_'] = obj
|
||||
if isinstance(obj, defer.Deferred):
|
||||
# XXX Ick, where is my "hasFired()" interface?
|
||||
if hasattr(obj, "result"):
|
||||
self.write(repr(obj))
|
||||
elif id(obj) in self._pendingDeferreds:
|
||||
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
|
||||
else:
|
||||
d = self._pendingDeferreds
|
||||
k = self.numDeferreds
|
||||
d[id(obj)] = (k, obj)
|
||||
self.numDeferreds += 1
|
||||
obj.addCallbacks(self._cbDisplayDeferred, self._ebDisplayDeferred,
|
||||
callbackArgs=(k, obj), errbackArgs=(k, obj))
|
||||
self.write("<Deferred #%d>" % (k,))
|
||||
elif obj is not None:
|
||||
self.write(repr(obj))
|
||||
|
||||
def _cbDisplayDeferred(self, result, k, obj):
|
||||
self.write("Deferred #%d called back: %r" % (k, result), True)
|
||||
del self._pendingDeferreds[id(obj)]
|
||||
return result
|
||||
|
||||
def _ebDisplayDeferred(self, failure, k, obj):
|
||||
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
|
||||
del self._pendingDeferreds[id(obj)]
|
||||
return failure
|
||||
|
||||
def write(self, data, async=False):
|
||||
self.handler.addOutput(data, async)
|
||||
|
||||
CTRL_C = '\x03'
|
||||
CTRL_D = '\x04'
|
||||
CTRL_BACKSLASH = '\x1c'
|
||||
CTRL_L = '\x0c'
|
||||
CTRL_A = '\x01'
|
||||
CTRL_E = '\x05'
|
||||
|
||||
class Manhole(recvline.HistoricRecvLine):
|
||||
"""Mediator between a fancy line source and an interactive interpreter.
|
||||
|
||||
This accepts lines from its transport and passes them on to a
|
||||
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
|
||||
with something approximating their normal terminal-mode behavior. It
|
||||
can optionally be constructed with a dict which will be used as the
|
||||
local namespace for any code executed.
|
||||
"""
|
||||
|
||||
namespace = None
|
||||
|
||||
def __init__(self, namespace=None):
|
||||
recvline.HistoricRecvLine.__init__(self)
|
||||
if namespace is not None:
|
||||
self.namespace = namespace.copy()
|
||||
|
||||
def connectionMade(self):
|
||||
recvline.HistoricRecvLine.connectionMade(self)
|
||||
self.interpreter = ManholeInterpreter(self, self.namespace)
|
||||
self.keyHandlers[CTRL_C] = self.handle_INT
|
||||
self.keyHandlers[CTRL_D] = self.handle_EOF
|
||||
self.keyHandlers[CTRL_L] = self.handle_FF
|
||||
self.keyHandlers[CTRL_A] = self.handle_HOME
|
||||
self.keyHandlers[CTRL_E] = self.handle_END
|
||||
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
|
||||
|
||||
|
||||
def handle_INT(self):
|
||||
"""
|
||||
Handle ^C as an interrupt keystroke by resetting the current input
|
||||
variables to their initial state.
|
||||
"""
|
||||
self.pn = 0
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
self.interpreter.resetBuffer()
|
||||
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write("KeyboardInterrupt")
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
|
||||
def handle_EOF(self):
|
||||
if self.lineBuffer:
|
||||
self.terminal.write('\a')
|
||||
else:
|
||||
self.handle_QUIT()
|
||||
|
||||
|
||||
def handle_FF(self):
|
||||
"""
|
||||
Handle a 'form feed' byte - generally used to request a screen
|
||||
refresh/redraw.
|
||||
"""
|
||||
self.terminal.eraseDisplay()
|
||||
self.terminal.cursorHome()
|
||||
self.drawInputLine()
|
||||
|
||||
|
||||
def handle_QUIT(self):
|
||||
self.terminal.loseConnection()
|
||||
|
||||
|
||||
def _needsNewline(self):
|
||||
w = self.terminal.lastWrite
|
||||
return not w.endswith('\n') and not w.endswith('\x1bE')
|
||||
|
||||
def addOutput(self, bytes, async=False):
|
||||
if async:
|
||||
self.terminal.eraseLine()
|
||||
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]))
|
||||
|
||||
self.terminal.write(bytes)
|
||||
|
||||
if async:
|
||||
if self._needsNewline():
|
||||
self.terminal.nextLine()
|
||||
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
if self.lineBuffer:
|
||||
oldBuffer = self.lineBuffer
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
self._deliverBuffer(oldBuffer)
|
||||
|
||||
def lineReceived(self, line):
|
||||
more = self.interpreter.push(line)
|
||||
self.pn = bool(more)
|
||||
if self._needsNewline():
|
||||
self.terminal.nextLine()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
|
||||
class VT102Writer:
|
||||
"""Colorizer for Python tokens.
|
||||
|
||||
A series of tokens are written to instances of this object. Each is
|
||||
colored in a particular way. The final line of the result of this is
|
||||
generally added to the output.
|
||||
"""
|
||||
|
||||
typeToColor = {
|
||||
'identifier': '\x1b[31m',
|
||||
'keyword': '\x1b[32m',
|
||||
'parameter': '\x1b[33m',
|
||||
'variable': '\x1b[1;33m',
|
||||
'string': '\x1b[35m',
|
||||
'number': '\x1b[36m',
|
||||
'op': '\x1b[37m'}
|
||||
|
||||
normalColor = '\x1b[0m'
|
||||
|
||||
def __init__(self):
|
||||
self.written = []
|
||||
|
||||
def color(self, type):
|
||||
r = self.typeToColor.get(type, '')
|
||||
return r
|
||||
|
||||
def write(self, token, type=None):
|
||||
if token and token != '\r':
|
||||
c = self.color(type)
|
||||
if c:
|
||||
self.written.append(c)
|
||||
self.written.append(token)
|
||||
if c:
|
||||
self.written.append(self.normalColor)
|
||||
|
||||
def __str__(self):
|
||||
s = ''.join(self.written)
|
||||
return s.strip('\n').splitlines()[-1]
|
||||
|
||||
def lastColorizedLine(source):
|
||||
"""Tokenize and colorize the given Python source.
|
||||
|
||||
Returns a VT102-format colorized version of the last line of C{source}.
|
||||
"""
|
||||
w = VT102Writer()
|
||||
p = TokenPrinter(w.write).printtoken
|
||||
s = StringIO.StringIO(source)
|
||||
|
||||
tokenize.tokenize(s.readline, p)
|
||||
|
||||
return str(w)
|
||||
|
||||
class ColoredManhole(Manhole):
|
||||
"""A REPL which syntax colors input as users type it.
|
||||
"""
|
||||
|
||||
def getSource(self):
|
||||
"""Return a string containing the currently entered source.
|
||||
|
||||
This is only the code which will be considered for execution
|
||||
next.
|
||||
"""
|
||||
return ('\n'.join(self.interpreter.buffer) +
|
||||
'\n' +
|
||||
''.join(self.lineBuffer))
|
||||
|
||||
|
||||
def characterReceived(self, ch, moreCharactersComing):
|
||||
if self.mode == 'insert':
|
||||
self.lineBuffer.insert(self.lineBufferIndex, ch)
|
||||
else:
|
||||
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
|
||||
self.lineBufferIndex += 1
|
||||
|
||||
if moreCharactersComing:
|
||||
# Skip it all, we'll get called with another character in
|
||||
# like 2 femtoseconds.
|
||||
return
|
||||
|
||||
if ch == ' ':
|
||||
# Don't bother to try to color whitespace
|
||||
self.terminal.write(ch)
|
||||
return
|
||||
|
||||
source = self.getSource()
|
||||
|
||||
# Try to write some junk
|
||||
try:
|
||||
coloredLine = lastColorizedLine(source)
|
||||
except tokenize.TokenError:
|
||||
# We couldn't do it. Strange. Oh well, just add the character.
|
||||
self.terminal.write(ch)
|
||||
else:
|
||||
# Success! Clear the source on this line.
|
||||
self.terminal.eraseLine()
|
||||
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]) - 1)
|
||||
|
||||
# And write a new, colorized one.
|
||||
self.terminal.write(self.ps[self.pn] + coloredLine)
|
||||
|
||||
# And move the cursor to where it belongs
|
||||
n = len(self.lineBuffer) - self.lineBufferIndex
|
||||
if n:
|
||||
self.terminal.cursorBackward(n)
|
@ -0,0 +1,146 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
insults/SSH integration support.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch import avatar, interfaces as iconch, error as econch
|
||||
from twisted.conch.ssh import factory, keys, session
|
||||
from twisted.python import components
|
||||
|
||||
from twisted.conch.insults import insults
|
||||
|
||||
class _Glue:
|
||||
"""A feeble class for making one attribute look like another.
|
||||
|
||||
This should be replaced with a real class at some point, probably.
|
||||
Try not to write new code that uses it.
|
||||
"""
|
||||
def __init__(self, **kw):
|
||||
self.__dict__.update(kw)
|
||||
|
||||
def __getattr__(self, name):
|
||||
raise AttributeError(self.name, "has no attribute", name)
|
||||
|
||||
class TerminalSessionTransport:
|
||||
def __init__(self, proto, chainedProtocol, avatar, width, height):
|
||||
self.proto = proto
|
||||
self.avatar = avatar
|
||||
self.chainedProtocol = chainedProtocol
|
||||
|
||||
protoSession = self.proto.session
|
||||
|
||||
self.proto.makeConnection(
|
||||
_Glue(write=self.chainedProtocol.dataReceived,
|
||||
loseConnection=lambda: avatar.conn.sendClose(protoSession),
|
||||
name="SSH Proto Transport"))
|
||||
|
||||
def loseConnection():
|
||||
self.proto.loseConnection()
|
||||
|
||||
self.chainedProtocol.makeConnection(
|
||||
_Glue(write=self.proto.write,
|
||||
loseConnection=loseConnection,
|
||||
name="Chained Proto Transport"))
|
||||
|
||||
# XXX TODO
|
||||
# chainedProtocol is supposed to be an ITerminalTransport,
|
||||
# maybe. That means perhaps its terminalProtocol attribute is
|
||||
# an ITerminalProtocol, it could be. So calling terminalSize
|
||||
# on that should do the right thing But it'd be nice to clean
|
||||
# this bit up.
|
||||
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
|
||||
|
||||
|
||||
|
||||
@implementer(iconch.ISession)
|
||||
class TerminalSession(components.Adapter):
|
||||
transportFactory = TerminalSessionTransport
|
||||
chainedProtocolFactory = insults.ServerProtocol
|
||||
|
||||
def getPty(self, term, windowSize, attrs):
|
||||
self.height, self.width = windowSize[:2]
|
||||
|
||||
def openShell(self, proto):
|
||||
self.transportFactory(
|
||||
proto, self.chainedProtocolFactory(),
|
||||
iconch.IConchUser(self.original),
|
||||
self.width, self.height)
|
||||
|
||||
def execCommand(self, proto, cmd):
|
||||
raise econch.ConchError("Cannot execute commands")
|
||||
|
||||
def closed(self):
|
||||
pass
|
||||
|
||||
class TerminalUser(avatar.ConchUser, components.Adapter):
|
||||
def __init__(self, original, avatarId):
|
||||
components.Adapter.__init__(self, original)
|
||||
avatar.ConchUser.__init__(self)
|
||||
self.channelLookup['session'] = session.SSHSession
|
||||
|
||||
class TerminalRealm:
|
||||
userFactory = TerminalUser
|
||||
sessionFactory = TerminalSession
|
||||
|
||||
transportFactory = TerminalSessionTransport
|
||||
chainedProtocolFactory = insults.ServerProtocol
|
||||
|
||||
def _getAvatar(self, avatarId):
|
||||
comp = components.Componentized()
|
||||
user = self.userFactory(comp, avatarId)
|
||||
sess = self.sessionFactory(comp)
|
||||
|
||||
sess.transportFactory = self.transportFactory
|
||||
sess.chainedProtocolFactory = self.chainedProtocolFactory
|
||||
|
||||
comp.setComponent(iconch.IConchUser, user)
|
||||
comp.setComponent(iconch.ISession, sess)
|
||||
|
||||
return user
|
||||
|
||||
def __init__(self, transportFactory=None):
|
||||
if transportFactory is not None:
|
||||
self.transportFactory = transportFactory
|
||||
|
||||
def requestAvatar(self, avatarId, mind, *interfaces):
|
||||
for i in interfaces:
|
||||
if i is iconch.IConchUser:
|
||||
return (iconch.IConchUser,
|
||||
self._getAvatar(avatarId),
|
||||
lambda: None)
|
||||
raise NotImplementedError()
|
||||
|
||||
class ConchFactory(factory.SSHFactory):
|
||||
publicKey = 'ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBEvLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYLh5KmRpslkYHRivcJSkbh/C+BR3utDS555mV'
|
||||
|
||||
publicKeys = {
|
||||
'ssh-rsa' : keys.Key.fromString(publicKey)
|
||||
}
|
||||
del publicKey
|
||||
|
||||
privateKey = """-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
|
||||
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
|
||||
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
|
||||
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
|
||||
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
|
||||
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
|
||||
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
|
||||
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
|
||||
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
|
||||
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
privateKeys = {
|
||||
'ssh-rsa' : keys.Key.fromString(privateKey)
|
||||
}
|
||||
del privateKey
|
||||
|
||||
def __init__(self, portal):
|
||||
self.portal = portal
|
@ -0,0 +1,123 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
TAP plugin for creating telnet- and ssh-accessible manhole servers.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import protocol
|
||||
from twisted.application import service, strports
|
||||
from twisted.cred import portal, checkers
|
||||
from twisted.python import usage
|
||||
|
||||
from twisted.conch.insults import insults
|
||||
from twisted.conch import manhole, manhole_ssh, telnet
|
||||
|
||||
class makeTelnetProtocol:
|
||||
def __init__(self, portal):
|
||||
self.portal = portal
|
||||
|
||||
def __call__(self):
|
||||
auth = telnet.AuthenticatingTelnetProtocol
|
||||
args = (self.portal,)
|
||||
return telnet.TelnetTransport(auth, *args)
|
||||
|
||||
class chainedProtocolFactory:
|
||||
def __init__(self, namespace):
|
||||
self.namespace = namespace
|
||||
|
||||
def __call__(self):
|
||||
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
|
||||
|
||||
|
||||
|
||||
@implementer(portal.IRealm)
|
||||
class _StupidRealm:
|
||||
def __init__(self, proto, *a, **kw):
|
||||
self.protocolFactory = proto
|
||||
self.protocolArgs = a
|
||||
self.protocolKwArgs = kw
|
||||
|
||||
def requestAvatar(self, avatarId, *interfaces):
|
||||
if telnet.ITelnetProtocol in interfaces:
|
||||
return (telnet.ITelnetProtocol,
|
||||
self.protocolFactory(*self.protocolArgs, **self.protocolKwArgs),
|
||||
lambda: None)
|
||||
raise NotImplementedError()
|
||||
|
||||
class Options(usage.Options):
|
||||
optParameters = [
|
||||
["telnetPort", "t", None, "strports description of the address on which to listen for telnet connections"],
|
||||
["sshPort", "s", None, "strports description of the address on which to listen for ssh connections"],
|
||||
["passwd", "p", "/etc/passwd", "name of a passwd(5)-format username/password file"]]
|
||||
|
||||
def __init__(self):
|
||||
usage.Options.__init__(self)
|
||||
self['namespace'] = None
|
||||
|
||||
def postOptions(self):
|
||||
if self['telnetPort'] is None and self['sshPort'] is None:
|
||||
raise usage.UsageError("At least one of --telnetPort and --sshPort must be specified")
|
||||
|
||||
def makeService(options):
|
||||
"""Create a manhole server service.
|
||||
|
||||
@type options: C{dict}
|
||||
@param options: A mapping describing the configuration of
|
||||
the desired service. Recognized key/value pairs are::
|
||||
|
||||
"telnetPort": strports description of the address on which
|
||||
to listen for telnet connections. If None,
|
||||
no telnet service will be started.
|
||||
|
||||
"sshPort": strports description of the address on which to
|
||||
listen for ssh connections. If None, no ssh
|
||||
service will be started.
|
||||
|
||||
"namespace": dictionary containing desired initial locals
|
||||
for manhole connections. If None, an empty
|
||||
dictionary will be used.
|
||||
|
||||
"passwd": Name of a passwd(5)-format username/password file.
|
||||
|
||||
@rtype: L{twisted.application.service.IService}
|
||||
@return: A manhole service.
|
||||
"""
|
||||
|
||||
svc = service.MultiService()
|
||||
|
||||
namespace = options['namespace']
|
||||
if namespace is None:
|
||||
namespace = {}
|
||||
|
||||
checker = checkers.FilePasswordDB(options['passwd'])
|
||||
|
||||
if options['telnetPort']:
|
||||
telnetRealm = _StupidRealm(telnet.TelnetBootstrapProtocol,
|
||||
insults.ServerProtocol,
|
||||
manhole.ColoredManhole,
|
||||
namespace)
|
||||
|
||||
telnetPortal = portal.Portal(telnetRealm, [checker])
|
||||
|
||||
telnetFactory = protocol.ServerFactory()
|
||||
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
|
||||
telnetService = strports.service(options['telnetPort'],
|
||||
telnetFactory)
|
||||
telnetService.setServiceParent(svc)
|
||||
|
||||
if options['sshPort']:
|
||||
sshRealm = manhole_ssh.TerminalRealm()
|
||||
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
|
||||
|
||||
sshPortal = portal.Portal(sshRealm, [checker])
|
||||
sshFactory = manhole_ssh.ConchFactory(sshPortal)
|
||||
sshService = strports.service(options['sshPort'],
|
||||
sshFactory)
|
||||
sshService.setServiceParent(svc)
|
||||
|
||||
return svc
|
@ -0,0 +1,49 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_mixin -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Experimental optimization
|
||||
|
||||
This module provides a single mixin class which allows protocols to
|
||||
collapse numerous small writes into a single larger one.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
from twisted.internet import reactor
|
||||
|
||||
class BufferingMixin:
|
||||
"""Mixin which adds write buffering.
|
||||
"""
|
||||
_delayedWriteCall = None
|
||||
bytes = None
|
||||
|
||||
DELAY = 0.0
|
||||
|
||||
def schedule(self):
|
||||
return reactor.callLater(self.DELAY, self.flush)
|
||||
|
||||
def reschedule(self, token):
|
||||
token.reset(self.DELAY)
|
||||
|
||||
def write(self, bytes):
|
||||
"""Buffer some bytes to be written soon.
|
||||
|
||||
Every call to this function delays the real write by C{self.DELAY}
|
||||
seconds. When the delay expires, all collected bytes are written
|
||||
to the underlying transport using L{ITransport.writeSequence}.
|
||||
"""
|
||||
if self._delayedWriteCall is None:
|
||||
self.bytes = []
|
||||
self._delayedWriteCall = self.schedule()
|
||||
else:
|
||||
self.reschedule(self._delayedWriteCall)
|
||||
self.bytes.append(bytes)
|
||||
|
||||
def flush(self):
|
||||
"""Flush the buffer immediately.
|
||||
"""
|
||||
self._delayedWriteCall = None
|
||||
self.transport.writeSequence(self.bytes)
|
||||
self.bytes = None
|
@ -0,0 +1,11 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
Support for OpenSSH configuration files.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
@ -0,0 +1,73 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Factory for reading openssh configuration files: public keys, private keys, and
|
||||
moduli file.
|
||||
"""
|
||||
|
||||
import os, errno
|
||||
|
||||
from twisted.python import log
|
||||
from twisted.python.util import runAsEffectiveUser
|
||||
|
||||
from twisted.conch.ssh import keys, factory, common
|
||||
from twisted.conch.openssh_compat import primes
|
||||
|
||||
|
||||
|
||||
class OpenSSHFactory(factory.SSHFactory):
|
||||
dataRoot = '/usr/local/etc'
|
||||
moduliRoot = '/usr/local/etc' # for openbsd which puts moduli in a different
|
||||
# directory from keys
|
||||
|
||||
|
||||
def getPublicKeys(self):
|
||||
"""
|
||||
Return the server public keys.
|
||||
"""
|
||||
ks = {}
|
||||
for filename in os.listdir(self.dataRoot):
|
||||
if filename[:9] == 'ssh_host_' and filename[-8:]=='_key.pub':
|
||||
try:
|
||||
k = keys.Key.fromFile(
|
||||
os.path.join(self.dataRoot, filename))
|
||||
t = common.getNS(k.blob())[0]
|
||||
ks[t] = k
|
||||
except Exception, e:
|
||||
log.msg('bad public key file %s: %s' % (filename, e))
|
||||
return ks
|
||||
|
||||
|
||||
def getPrivateKeys(self):
|
||||
"""
|
||||
Return the server private keys.
|
||||
"""
|
||||
privateKeys = {}
|
||||
for filename in os.listdir(self.dataRoot):
|
||||
if filename[:9] == 'ssh_host_' and filename[-4:]=='_key':
|
||||
fullPath = os.path.join(self.dataRoot, filename)
|
||||
try:
|
||||
key = keys.Key.fromFile(fullPath)
|
||||
except IOError, e:
|
||||
if e.errno == errno.EACCES:
|
||||
# Not allowed, let's switch to root
|
||||
key = runAsEffectiveUser(0, 0, keys.Key.fromFile, fullPath)
|
||||
keyType = keys.objectType(key.keyObject)
|
||||
privateKeys[keyType] = key
|
||||
else:
|
||||
raise
|
||||
except Exception, e:
|
||||
log.msg('bad private key file %s: %s' % (filename, e))
|
||||
else:
|
||||
keyType = keys.objectType(key.keyObject)
|
||||
privateKeys[keyType] = key
|
||||
return privateKeys
|
||||
|
||||
|
||||
def getPrimes(self):
|
||||
try:
|
||||
return primes.parseModuliFile(self.moduliRoot+'/moduli')
|
||||
except IOError:
|
||||
return None
|
@ -0,0 +1,26 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
def parseModuliFile(filename):
|
||||
lines = open(filename).readlines()
|
||||
primes = {}
|
||||
for l in lines:
|
||||
l = l.strip()
|
||||
if not l or l[0]=='#':
|
||||
continue
|
||||
tim, typ, tst, tri, size, gen, mod = l.split()
|
||||
size = int(size) + 1
|
||||
gen = long(gen)
|
||||
mod = long(mod, 16)
|
||||
if not primes.has_key(size):
|
||||
primes[size] = []
|
||||
primes[size].append((gen, mod))
|
||||
return primes
|
@ -0,0 +1,331 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_recvline -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Basic line editing support.
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import string
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.insults import insults, helper
|
||||
|
||||
from twisted.python import log, reflect
|
||||
|
||||
_counters = {}
|
||||
class Logging(object):
|
||||
"""Wrapper which logs attribute lookups.
|
||||
|
||||
This was useful in debugging something, I guess. I forget what.
|
||||
It can probably be deleted or moved somewhere more appropriate.
|
||||
Nothing special going on here, really.
|
||||
"""
|
||||
def __init__(self, original):
|
||||
self.original = original
|
||||
key = reflect.qual(original.__class__)
|
||||
count = _counters.get(key, 0)
|
||||
_counters[key] = count + 1
|
||||
self._logFile = file(key + '-' + str(count), 'w')
|
||||
|
||||
def __str__(self):
|
||||
return str(super(Logging, self).__getattribute__('original'))
|
||||
|
||||
def __repr__(self):
|
||||
return repr(super(Logging, self).__getattribute__('original'))
|
||||
|
||||
def __getattribute__(self, name):
|
||||
original = super(Logging, self).__getattribute__('original')
|
||||
logFile = super(Logging, self).__getattribute__('_logFile')
|
||||
logFile.write(name + '\n')
|
||||
return getattr(original, name)
|
||||
|
||||
|
||||
|
||||
@implementer(insults.ITerminalTransport)
|
||||
class TransportSequence(object):
|
||||
"""An L{ITerminalTransport} implementation which forwards calls to
|
||||
one or more other L{ITerminalTransport}s.
|
||||
|
||||
This is a cheap way for servers to keep track of the state they
|
||||
expect the client to see, since all terminal manipulations can be
|
||||
send to the real client and to a terminal emulator that lives in
|
||||
the server process.
|
||||
"""
|
||||
|
||||
for keyID in ('UP_ARROW', 'DOWN_ARROW', 'RIGHT_ARROW', 'LEFT_ARROW',
|
||||
'HOME', 'INSERT', 'DELETE', 'END', 'PGUP', 'PGDN',
|
||||
'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'F9',
|
||||
'F10', 'F11', 'F12'):
|
||||
exec '%s = object()' % (keyID,)
|
||||
|
||||
TAB = '\t'
|
||||
BACKSPACE = '\x7f'
|
||||
|
||||
def __init__(self, *transports):
|
||||
assert transports, "Cannot construct a TransportSequence with no transports"
|
||||
self.transports = transports
|
||||
|
||||
for method in insults.ITerminalTransport:
|
||||
exec """\
|
||||
def %s(self, *a, **kw):
|
||||
for tpt in self.transports:
|
||||
result = tpt.%s(*a, **kw)
|
||||
return result
|
||||
""" % (method, method)
|
||||
|
||||
class LocalTerminalBufferMixin(object):
|
||||
"""A mixin for RecvLine subclasses which records the state of the terminal.
|
||||
|
||||
This is accomplished by performing all L{ITerminalTransport} operations on both
|
||||
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
|
||||
|
||||
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
|
||||
will be made to keep up to date with the actual terminal
|
||||
associated with this protocol instance.
|
||||
"""
|
||||
|
||||
def makeConnection(self, transport):
|
||||
self.terminalCopy = helper.TerminalBuffer()
|
||||
self.terminalCopy.connectionMade()
|
||||
return super(LocalTerminalBufferMixin, self).makeConnection(
|
||||
TransportSequence(transport, self.terminalCopy))
|
||||
|
||||
def __str__(self):
|
||||
return str(self.terminalCopy)
|
||||
|
||||
class RecvLine(insults.TerminalProtocol):
|
||||
"""L{TerminalProtocol} which adds line editing features.
|
||||
|
||||
Clients will be prompted for lines of input with all the usual
|
||||
features: character echoing, left and right arrow support for
|
||||
moving the cursor to different areas of the line buffer, backspace
|
||||
and delete for removing characters, and insert for toggling
|
||||
between typeover and insert mode. Tabs will be expanded to enough
|
||||
spaces to move the cursor to the next tabstop (every four
|
||||
characters by default). Enter causes the line buffer to be
|
||||
cleared and the line to be passed to the lineReceived() method
|
||||
which, by default, does nothing. Subclasses are responsible for
|
||||
redrawing the input prompt (this will probably change).
|
||||
"""
|
||||
width = 80
|
||||
height = 24
|
||||
|
||||
TABSTOP = 4
|
||||
|
||||
ps = ('>>> ', '... ')
|
||||
pn = 0
|
||||
_printableChars = set(string.printable)
|
||||
|
||||
def connectionMade(self):
|
||||
# A list containing the characters making up the current line
|
||||
self.lineBuffer = []
|
||||
|
||||
# A zero-based (wtf else?) index into self.lineBuffer.
|
||||
# Indicates the current cursor position.
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
t = self.terminal
|
||||
# A map of keyIDs to bound instance methods.
|
||||
self.keyHandlers = {
|
||||
t.LEFT_ARROW: self.handle_LEFT,
|
||||
t.RIGHT_ARROW: self.handle_RIGHT,
|
||||
t.TAB: self.handle_TAB,
|
||||
|
||||
# Both of these should not be necessary, but figuring out
|
||||
# which is necessary is a huge hassle.
|
||||
'\r': self.handle_RETURN,
|
||||
'\n': self.handle_RETURN,
|
||||
|
||||
t.BACKSPACE: self.handle_BACKSPACE,
|
||||
t.DELETE: self.handle_DELETE,
|
||||
t.INSERT: self.handle_INSERT,
|
||||
t.HOME: self.handle_HOME,
|
||||
t.END: self.handle_END}
|
||||
|
||||
self.initializeScreen()
|
||||
|
||||
def initializeScreen(self):
|
||||
# Hmm, state sucks. Oh well.
|
||||
# For now we will just take over the whole terminal.
|
||||
self.terminal.reset()
|
||||
self.terminal.write(self.ps[self.pn])
|
||||
# XXX Note: I would prefer to default to starting in insert
|
||||
# mode, however this does not seem to actually work! I do not
|
||||
# know why. This is probably of interest to implementors
|
||||
# subclassing RecvLine.
|
||||
|
||||
# XXX XXX Note: But the unit tests all expect the initial mode
|
||||
# to be insert right now. Fuck, there needs to be a way to
|
||||
# query the current mode or something.
|
||||
# self.setTypeoverMode()
|
||||
self.setInsertMode()
|
||||
|
||||
def currentLineBuffer(self):
|
||||
s = ''.join(self.lineBuffer)
|
||||
return s[:self.lineBufferIndex], s[self.lineBufferIndex:]
|
||||
|
||||
def setInsertMode(self):
|
||||
self.mode = 'insert'
|
||||
self.terminal.setModes([insults.modes.IRM])
|
||||
|
||||
def setTypeoverMode(self):
|
||||
self.mode = 'typeover'
|
||||
self.terminal.resetModes([insults.modes.IRM])
|
||||
|
||||
def drawInputLine(self):
|
||||
"""
|
||||
Write a line containing the current input prompt and the current line
|
||||
buffer at the current cursor position.
|
||||
"""
|
||||
self.terminal.write(self.ps[self.pn] + ''.join(self.lineBuffer))
|
||||
|
||||
def terminalSize(self, width, height):
|
||||
# XXX - Clear the previous input line, redraw it at the new
|
||||
# cursor position
|
||||
self.terminal.eraseDisplay()
|
||||
self.terminal.cursorHome()
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.drawInputLine()
|
||||
|
||||
def unhandledControlSequence(self, seq):
|
||||
pass
|
||||
|
||||
def keystrokeReceived(self, keyID, modifier):
|
||||
m = self.keyHandlers.get(keyID)
|
||||
if m is not None:
|
||||
m()
|
||||
elif keyID in self._printableChars:
|
||||
self.characterReceived(keyID, False)
|
||||
else:
|
||||
log.msg("Received unhandled keyID: %r" % (keyID,))
|
||||
|
||||
def characterReceived(self, ch, moreCharactersComing):
|
||||
if self.mode == 'insert':
|
||||
self.lineBuffer.insert(self.lineBufferIndex, ch)
|
||||
else:
|
||||
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
|
||||
self.lineBufferIndex += 1
|
||||
self.terminal.write(ch)
|
||||
|
||||
def handle_TAB(self):
|
||||
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
|
||||
self.terminal.cursorForward(n)
|
||||
self.lineBufferIndex += n
|
||||
self.lineBuffer.extend(' ' * n)
|
||||
|
||||
def handle_LEFT(self):
|
||||
if self.lineBufferIndex > 0:
|
||||
self.lineBufferIndex -= 1
|
||||
self.terminal.cursorBackward()
|
||||
|
||||
def handle_RIGHT(self):
|
||||
if self.lineBufferIndex < len(self.lineBuffer):
|
||||
self.lineBufferIndex += 1
|
||||
self.terminal.cursorForward()
|
||||
|
||||
def handle_HOME(self):
|
||||
if self.lineBufferIndex:
|
||||
self.terminal.cursorBackward(self.lineBufferIndex)
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
def handle_END(self):
|
||||
offset = len(self.lineBuffer) - self.lineBufferIndex
|
||||
if offset:
|
||||
self.terminal.cursorForward(offset)
|
||||
self.lineBufferIndex = len(self.lineBuffer)
|
||||
|
||||
def handle_BACKSPACE(self):
|
||||
if self.lineBufferIndex > 0:
|
||||
self.lineBufferIndex -= 1
|
||||
del self.lineBuffer[self.lineBufferIndex]
|
||||
self.terminal.cursorBackward()
|
||||
self.terminal.deleteCharacter()
|
||||
|
||||
def handle_DELETE(self):
|
||||
if self.lineBufferIndex < len(self.lineBuffer):
|
||||
del self.lineBuffer[self.lineBufferIndex]
|
||||
self.terminal.deleteCharacter()
|
||||
|
||||
def handle_RETURN(self):
|
||||
line = ''.join(self.lineBuffer)
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
self.terminal.nextLine()
|
||||
self.lineReceived(line)
|
||||
|
||||
def handle_INSERT(self):
|
||||
assert self.mode in ('typeover', 'insert')
|
||||
if self.mode == 'typeover':
|
||||
self.setInsertMode()
|
||||
else:
|
||||
self.setTypeoverMode()
|
||||
|
||||
def lineReceived(self, line):
|
||||
pass
|
||||
|
||||
class HistoricRecvLine(RecvLine):
|
||||
"""L{TerminalProtocol} which adds both basic line-editing features and input history.
|
||||
|
||||
Everything supported by L{RecvLine} is also supported by this class. In addition, the
|
||||
up and down arrows traverse the input history. Each received line is automatically
|
||||
added to the end of the input history.
|
||||
"""
|
||||
def connectionMade(self):
|
||||
RecvLine.connectionMade(self)
|
||||
|
||||
self.historyLines = []
|
||||
self.historyPosition = 0
|
||||
|
||||
t = self.terminal
|
||||
self.keyHandlers.update({t.UP_ARROW: self.handle_UP,
|
||||
t.DOWN_ARROW: self.handle_DOWN})
|
||||
|
||||
def currentHistoryBuffer(self):
|
||||
b = tuple(self.historyLines)
|
||||
return b[:self.historyPosition], b[self.historyPosition:]
|
||||
|
||||
def _deliverBuffer(self, buf):
|
||||
if buf:
|
||||
for ch in buf[:-1]:
|
||||
self.characterReceived(ch, True)
|
||||
self.characterReceived(buf[-1], False)
|
||||
|
||||
def handle_UP(self):
|
||||
if self.lineBuffer and self.historyPosition == len(self.historyLines):
|
||||
self.historyLines.append(self.lineBuffer)
|
||||
if self.historyPosition > 0:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition -= 1
|
||||
self.lineBuffer = []
|
||||
|
||||
self._deliverBuffer(self.historyLines[self.historyPosition])
|
||||
|
||||
def handle_DOWN(self):
|
||||
if self.historyPosition < len(self.historyLines) - 1:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition += 1
|
||||
self.lineBuffer = []
|
||||
|
||||
self._deliverBuffer(self.historyLines[self.historyPosition])
|
||||
else:
|
||||
self.handle_HOME()
|
||||
self.terminal.eraseToLineEnd()
|
||||
|
||||
self.historyPosition = len(self.historyLines)
|
||||
self.lineBuffer = []
|
||||
self.lineBufferIndex = 0
|
||||
|
||||
def handle_RETURN(self):
|
||||
if self.lineBuffer:
|
||||
self.historyLines.append(''.join(self.lineBuffer))
|
||||
self.historyPosition = len(self.historyLines)
|
||||
return RecvLine.handle_RETURN(self)
|
@ -0,0 +1 @@
|
||||
'conch scripts'
|
@ -0,0 +1,834 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_cftp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation module for the I{cftp} command.
|
||||
"""
|
||||
|
||||
import os, sys, getpass, struct, tty, fcntl, stat
|
||||
import fnmatch, pwd, glob
|
||||
|
||||
from twisted.conch.client import connect, default, options
|
||||
from twisted.conch.ssh import connection, common
|
||||
from twisted.conch.ssh import channel, filetransfer
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import reactor, stdio, defer, utils
|
||||
from twisted.python import log, usage, failure
|
||||
|
||||
class ClientOptions(options.ConchOptions):
|
||||
|
||||
synopsis = """Usage: cftp [options] [user@]host
|
||||
cftp [options] [user@]host[:dir[/]]
|
||||
cftp [options] [user@]host[:file [localfile]]
|
||||
"""
|
||||
longdesc = ("cftp is a client for logging into a remote machine and "
|
||||
"executing commands to send and receive file information")
|
||||
|
||||
optParameters = [
|
||||
['buffersize', 'B', 32768, 'Size of the buffer to use for sending/receiving.'],
|
||||
['batchfile', 'b', None, 'File to read commands from, or \'-\' for stdin.'],
|
||||
['requests', 'R', 5, 'Number of requests to make before waiting for a reply.'],
|
||||
['subsystem', 's', 'sftp', 'Subsystem/server program to connect to.']]
|
||||
|
||||
compData = usage.Completions(
|
||||
descriptions={
|
||||
"buffersize": "Size of send/receive buffer (default: 32768)"},
|
||||
extraActions=[usage.CompleteUserAtHost(),
|
||||
usage.CompleteFiles(descr="local file")])
|
||||
|
||||
def parseArgs(self, host, localPath=None):
|
||||
self['remotePath'] = ''
|
||||
if ':' in host:
|
||||
host, self['remotePath'] = host.split(':', 1)
|
||||
self['remotePath'].rstrip('/')
|
||||
self['host'] = host
|
||||
self['localPath'] = localPath
|
||||
|
||||
def run():
|
||||
# import hotshot
|
||||
# prof = hotshot.Profile('cftp.prof')
|
||||
# prof.start()
|
||||
args = sys.argv[1:]
|
||||
if '-l' in args: # cvs is an idiot
|
||||
i = args.index('-l')
|
||||
args = args[i:i+2]+args
|
||||
del args[i+2:i+4]
|
||||
options = ClientOptions()
|
||||
try:
|
||||
options.parseOptions(args)
|
||||
except usage.UsageError, u:
|
||||
print 'ERROR: %s' % u
|
||||
sys.exit(1)
|
||||
if options['log']:
|
||||
realout = sys.stdout
|
||||
log.startLogging(sys.stderr)
|
||||
sys.stdout = realout
|
||||
else:
|
||||
log.discardLogs()
|
||||
doConnect(options)
|
||||
reactor.run()
|
||||
# prof.stop()
|
||||
# prof.close()
|
||||
|
||||
def handleError():
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
try:
|
||||
reactor.stop()
|
||||
except: pass
|
||||
log.err(failure.Failure())
|
||||
raise
|
||||
|
||||
def doConnect(options):
|
||||
# log.deferr = handleError # HACK
|
||||
if '@' in options['host']:
|
||||
options['user'], options['host'] = options['host'].split('@',1)
|
||||
host = options['host']
|
||||
if not options['user']:
|
||||
options['user'] = getpass.getuser()
|
||||
if not options['port']:
|
||||
options['port'] = 22
|
||||
else:
|
||||
options['port'] = int(options['port'])
|
||||
host = options['host']
|
||||
port = options['port']
|
||||
conn = SSHConnection()
|
||||
conn.options = options
|
||||
vhk = default.verifyHostKey
|
||||
uao = default.SSHUserAuthClient(options['user'], options, conn)
|
||||
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
|
||||
|
||||
def _ebExit(f):
|
||||
#global exitStatus
|
||||
if hasattr(f.value, 'value'):
|
||||
s = f.value.value
|
||||
else:
|
||||
s = str(f)
|
||||
print s
|
||||
#exitStatus = "conch: exiting with error %s" % f
|
||||
try:
|
||||
reactor.stop()
|
||||
except: pass
|
||||
|
||||
def _ignore(*args): pass
|
||||
|
||||
class FileWrapper:
|
||||
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
self.total = 0.0
|
||||
f.seek(0, 2) # seek to the end
|
||||
self.size = f.tell()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.f, attr)
|
||||
|
||||
class StdioClient(basic.LineReceiver):
|
||||
|
||||
_pwd = pwd
|
||||
|
||||
ps = 'cftp> '
|
||||
delimiter = '\n'
|
||||
|
||||
reactor = reactor
|
||||
|
||||
def __init__(self, client, f = None):
|
||||
self.client = client
|
||||
self.currentDirectory = ''
|
||||
self.file = f
|
||||
self.useProgressBar = (not f and 1) or 0
|
||||
|
||||
def connectionMade(self):
|
||||
self.client.realPath('').addCallback(self._cbSetCurDir)
|
||||
|
||||
def _cbSetCurDir(self, path):
|
||||
self.currentDirectory = path
|
||||
self._newLine()
|
||||
|
||||
def lineReceived(self, line):
|
||||
if self.client.transport.localClosed:
|
||||
return
|
||||
log.msg('got line %s' % repr(line))
|
||||
line = line.lstrip()
|
||||
if not line:
|
||||
self._newLine()
|
||||
return
|
||||
if self.file and line.startswith('-'):
|
||||
self.ignoreErrors = 1
|
||||
line = line[1:]
|
||||
else:
|
||||
self.ignoreErrors = 0
|
||||
d = self._dispatchCommand(line)
|
||||
if d is not None:
|
||||
d.addCallback(self._cbCommand)
|
||||
d.addErrback(self._ebCommand)
|
||||
|
||||
|
||||
def _dispatchCommand(self, line):
|
||||
if ' ' in line:
|
||||
command, rest = line.split(' ', 1)
|
||||
rest = rest.lstrip()
|
||||
else:
|
||||
command, rest = line, ''
|
||||
if command.startswith('!'): # command
|
||||
f = self.cmd_EXEC
|
||||
rest = (command[1:] + ' ' + rest).strip()
|
||||
else:
|
||||
command = command.upper()
|
||||
log.msg('looking up cmd %s' % command)
|
||||
f = getattr(self, 'cmd_%s' % command, None)
|
||||
if f is not None:
|
||||
return defer.maybeDeferred(f, rest)
|
||||
else:
|
||||
self._ebCommand(failure.Failure(NotImplementedError(
|
||||
"No command called `%s'" % command)))
|
||||
self._newLine()
|
||||
|
||||
def _printFailure(self, f):
|
||||
log.msg(f)
|
||||
e = f.trap(NotImplementedError, filetransfer.SFTPError, OSError, IOError)
|
||||
if e == NotImplementedError:
|
||||
self.transport.write(self.cmd_HELP(''))
|
||||
elif e == filetransfer.SFTPError:
|
||||
self.transport.write("remote error %i: %s\n" %
|
||||
(f.value.code, f.value.message))
|
||||
elif e in (OSError, IOError):
|
||||
self.transport.write("local error %i: %s\n" %
|
||||
(f.value.errno, f.value.strerror))
|
||||
|
||||
def _newLine(self):
|
||||
if self.client.transport.localClosed:
|
||||
return
|
||||
self.transport.write(self.ps)
|
||||
self.ignoreErrors = 0
|
||||
if self.file:
|
||||
l = self.file.readline()
|
||||
if not l:
|
||||
self.client.transport.loseConnection()
|
||||
else:
|
||||
self.transport.write(l)
|
||||
self.lineReceived(l.strip())
|
||||
|
||||
def _cbCommand(self, result):
|
||||
if result is not None:
|
||||
self.transport.write(result)
|
||||
if not result.endswith('\n'):
|
||||
self.transport.write('\n')
|
||||
self._newLine()
|
||||
|
||||
def _ebCommand(self, f):
|
||||
self._printFailure(f)
|
||||
if self.file and not self.ignoreErrors:
|
||||
self.client.transport.loseConnection()
|
||||
self._newLine()
|
||||
|
||||
def cmd_CD(self, path):
|
||||
path, rest = self._getFilename(path)
|
||||
if not path.endswith('/'):
|
||||
path += '/'
|
||||
newPath = path and os.path.join(self.currentDirectory, path) or ''
|
||||
d = self.client.openDirectory(newPath)
|
||||
d.addCallback(self._cbCd)
|
||||
d.addErrback(self._ebCommand)
|
||||
return d
|
||||
|
||||
def _cbCd(self, directory):
|
||||
directory.close()
|
||||
d = self.client.realPath(directory.name)
|
||||
d.addCallback(self._cbCurDir)
|
||||
return d
|
||||
|
||||
def _cbCurDir(self, path):
|
||||
self.currentDirectory = path
|
||||
|
||||
def cmd_CHGRP(self, rest):
|
||||
grp, rest = rest.split(None, 1)
|
||||
path, rest = self._getFilename(rest)
|
||||
grp = int(grp)
|
||||
d = self.client.getAttrs(path)
|
||||
d.addCallback(self._cbSetUsrGrp, path, grp=grp)
|
||||
return d
|
||||
|
||||
def cmd_CHMOD(self, rest):
|
||||
mod, rest = rest.split(None, 1)
|
||||
path, rest = self._getFilename(rest)
|
||||
mod = int(mod, 8)
|
||||
d = self.client.setAttrs(path, {'permissions':mod})
|
||||
d.addCallback(_ignore)
|
||||
return d
|
||||
|
||||
def cmd_CHOWN(self, rest):
|
||||
usr, rest = rest.split(None, 1)
|
||||
path, rest = self._getFilename(rest)
|
||||
usr = int(usr)
|
||||
d = self.client.getAttrs(path)
|
||||
d.addCallback(self._cbSetUsrGrp, path, usr=usr)
|
||||
return d
|
||||
|
||||
def _cbSetUsrGrp(self, attrs, path, usr=None, grp=None):
|
||||
new = {}
|
||||
new['uid'] = (usr is not None) and usr or attrs['uid']
|
||||
new['gid'] = (grp is not None) and grp or attrs['gid']
|
||||
d = self.client.setAttrs(path, new)
|
||||
d.addCallback(_ignore)
|
||||
return d
|
||||
|
||||
def cmd_GET(self, rest):
|
||||
remote, rest = self._getFilename(rest)
|
||||
if '*' in remote or '?' in remote: # wildcard
|
||||
if rest:
|
||||
local, rest = self._getFilename(rest)
|
||||
if not os.path.isdir(local):
|
||||
return "Wildcard get with non-directory target."
|
||||
else:
|
||||
local = ''
|
||||
d = self._remoteGlob(remote)
|
||||
d.addCallback(self._cbGetMultiple, local)
|
||||
return d
|
||||
if rest:
|
||||
local, rest = self._getFilename(rest)
|
||||
else:
|
||||
local = os.path.split(remote)[1]
|
||||
log.msg((remote, local))
|
||||
lf = file(local, 'w', 0)
|
||||
path = os.path.join(self.currentDirectory, remote)
|
||||
d = self.client.openFile(path, filetransfer.FXF_READ, {})
|
||||
d.addCallback(self._cbGetOpenFile, lf)
|
||||
d.addErrback(self._ebCloseLf, lf)
|
||||
return d
|
||||
|
||||
def _cbGetMultiple(self, files, local):
|
||||
#if self._useProgressBar: # one at a time
|
||||
# XXX this can be optimized for times w/o progress bar
|
||||
return self._cbGetMultipleNext(None, files, local)
|
||||
|
||||
def _cbGetMultipleNext(self, res, files, local):
|
||||
if isinstance(res, failure.Failure):
|
||||
self._printFailure(res)
|
||||
elif res:
|
||||
self.transport.write(res)
|
||||
if not res.endswith('\n'):
|
||||
self.transport.write('\n')
|
||||
if not files:
|
||||
return
|
||||
f = files.pop(0)[0]
|
||||
lf = file(os.path.join(local, os.path.split(f)[1]), 'w', 0)
|
||||
path = os.path.join(self.currentDirectory, f)
|
||||
d = self.client.openFile(path, filetransfer.FXF_READ, {})
|
||||
d.addCallback(self._cbGetOpenFile, lf)
|
||||
d.addErrback(self._ebCloseLf, lf)
|
||||
d.addBoth(self._cbGetMultipleNext, files, local)
|
||||
return d
|
||||
|
||||
def _ebCloseLf(self, f, lf):
|
||||
lf.close()
|
||||
return f
|
||||
|
||||
def _cbGetOpenFile(self, rf, lf):
|
||||
return rf.getAttrs().addCallback(self._cbGetFileSize, rf, lf)
|
||||
|
||||
def _cbGetFileSize(self, attrs, rf, lf):
|
||||
if not stat.S_ISREG(attrs['permissions']):
|
||||
rf.close()
|
||||
lf.close()
|
||||
return "Can't get non-regular file: %s" % rf.name
|
||||
rf.size = attrs['size']
|
||||
bufferSize = self.client.transport.conn.options['buffersize']
|
||||
numRequests = self.client.transport.conn.options['requests']
|
||||
rf.total = 0.0
|
||||
dList = []
|
||||
chunks = []
|
||||
startTime = self.reactor.seconds()
|
||||
for i in range(numRequests):
|
||||
d = self._cbGetRead('', rf, lf, chunks, 0, bufferSize, startTime)
|
||||
dList.append(d)
|
||||
dl = defer.DeferredList(dList, fireOnOneErrback=1)
|
||||
dl.addCallback(self._cbGetDone, rf, lf)
|
||||
return dl
|
||||
|
||||
def _getNextChunk(self, chunks):
|
||||
end = 0
|
||||
for chunk in chunks:
|
||||
if end == 'eof':
|
||||
return # nothing more to get
|
||||
if end != chunk[0]:
|
||||
i = chunks.index(chunk)
|
||||
chunks.insert(i, (end, chunk[0]))
|
||||
return (end, chunk[0] - end)
|
||||
end = chunk[1]
|
||||
bufSize = int(self.client.transport.conn.options['buffersize'])
|
||||
chunks.append((end, end + bufSize))
|
||||
return (end, bufSize)
|
||||
|
||||
def _cbGetRead(self, data, rf, lf, chunks, start, size, startTime):
|
||||
if data and isinstance(data, failure.Failure):
|
||||
log.msg('get read err: %s' % data)
|
||||
reason = data
|
||||
reason.trap(EOFError)
|
||||
i = chunks.index((start, start + size))
|
||||
del chunks[i]
|
||||
chunks.insert(i, (start, 'eof'))
|
||||
elif data:
|
||||
log.msg('get read data: %i' % len(data))
|
||||
lf.seek(start)
|
||||
lf.write(data)
|
||||
if len(data) != size:
|
||||
log.msg('got less than we asked for: %i < %i' %
|
||||
(len(data), size))
|
||||
i = chunks.index((start, start + size))
|
||||
del chunks[i]
|
||||
chunks.insert(i, (start, start + len(data)))
|
||||
rf.total += len(data)
|
||||
if self.useProgressBar:
|
||||
self._printProgressBar(rf, startTime)
|
||||
chunk = self._getNextChunk(chunks)
|
||||
if not chunk:
|
||||
return
|
||||
else:
|
||||
start, length = chunk
|
||||
log.msg('asking for %i -> %i' % (start, start+length))
|
||||
d = rf.readChunk(start, length)
|
||||
d.addBoth(self._cbGetRead, rf, lf, chunks, start, length, startTime)
|
||||
return d
|
||||
|
||||
def _cbGetDone(self, ignored, rf, lf):
|
||||
log.msg('get done')
|
||||
rf.close()
|
||||
lf.close()
|
||||
if self.useProgressBar:
|
||||
self.transport.write('\n')
|
||||
return "Transferred %s to %s" % (rf.name, lf.name)
|
||||
|
||||
def cmd_PUT(self, rest):
|
||||
local, rest = self._getFilename(rest)
|
||||
if '*' in local or '?' in local: # wildcard
|
||||
if rest:
|
||||
remote, rest = self._getFilename(rest)
|
||||
path = os.path.join(self.currentDirectory, remote)
|
||||
d = self.client.getAttrs(path)
|
||||
d.addCallback(self._cbPutTargetAttrs, remote, local)
|
||||
return d
|
||||
else:
|
||||
remote = ''
|
||||
files = glob.glob(local)
|
||||
return self._cbPutMultipleNext(None, files, remote)
|
||||
if rest:
|
||||
remote, rest = self._getFilename(rest)
|
||||
else:
|
||||
remote = os.path.split(local)[1]
|
||||
lf = file(local, 'r')
|
||||
path = os.path.join(self.currentDirectory, remote)
|
||||
flags = filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|filetransfer.FXF_TRUNC
|
||||
d = self.client.openFile(path, flags, {})
|
||||
d.addCallback(self._cbPutOpenFile, lf)
|
||||
d.addErrback(self._ebCloseLf, lf)
|
||||
return d
|
||||
|
||||
def _cbPutTargetAttrs(self, attrs, path, local):
|
||||
if not stat.S_ISDIR(attrs['permissions']):
|
||||
return "Wildcard put with non-directory target."
|
||||
# FIXME:7037:
|
||||
# Check what `files` variable should do here.
|
||||
return self._cbPutMultipleNext(None, files, path)
|
||||
|
||||
def _cbPutMultipleNext(self, res, files, path):
|
||||
if isinstance(res, failure.Failure):
|
||||
self._printFailure(res)
|
||||
elif res:
|
||||
self.transport.write(res)
|
||||
if not res.endswith('\n'):
|
||||
self.transport.write('\n')
|
||||
f = None
|
||||
while files and not f:
|
||||
try:
|
||||
f = files.pop(0)
|
||||
lf = file(f, 'r')
|
||||
except:
|
||||
self._printFailure(failure.Failure())
|
||||
f = None
|
||||
if not f:
|
||||
return
|
||||
name = os.path.split(f)[1]
|
||||
remote = os.path.join(self.currentDirectory, path, name)
|
||||
log.msg((name, remote, path))
|
||||
flags = filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|filetransfer.FXF_TRUNC
|
||||
d = self.client.openFile(remote, flags, {})
|
||||
d.addCallback(self._cbPutOpenFile, lf)
|
||||
d.addErrback(self._ebCloseLf, lf)
|
||||
d.addBoth(self._cbPutMultipleNext, files, path)
|
||||
return d
|
||||
|
||||
def _cbPutOpenFile(self, rf, lf):
|
||||
numRequests = self.client.transport.conn.options['requests']
|
||||
if self.useProgressBar:
|
||||
lf = FileWrapper(lf)
|
||||
dList = []
|
||||
chunks = []
|
||||
startTime = self.reactor.seconds()
|
||||
for i in range(numRequests):
|
||||
d = self._cbPutWrite(None, rf, lf, chunks, startTime)
|
||||
if d:
|
||||
dList.append(d)
|
||||
dl = defer.DeferredList(dList, fireOnOneErrback=1)
|
||||
dl.addCallback(self._cbPutDone, rf, lf)
|
||||
return dl
|
||||
|
||||
def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
|
||||
chunk = self._getNextChunk(chunks)
|
||||
start, size = chunk
|
||||
lf.seek(start)
|
||||
data = lf.read(size)
|
||||
if self.useProgressBar:
|
||||
lf.total += len(data)
|
||||
self._printProgressBar(lf, startTime)
|
||||
if data:
|
||||
d = rf.writeChunk(start, data)
|
||||
d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
|
||||
return d
|
||||
else:
|
||||
return
|
||||
|
||||
def _cbPutDone(self, ignored, rf, lf):
|
||||
lf.close()
|
||||
rf.close()
|
||||
if self.useProgressBar:
|
||||
self.transport.write('\n')
|
||||
return 'Transferred %s to %s' % (lf.name, rf.name)
|
||||
|
||||
def cmd_LCD(self, path):
|
||||
os.chdir(path)
|
||||
|
||||
def cmd_LN(self, rest):
|
||||
linkpath, rest = self._getFilename(rest)
|
||||
targetpath, rest = self._getFilename(rest)
|
||||
linkpath, targetpath = map(
|
||||
lambda x: os.path.join(self.currentDirectory, x),
|
||||
(linkpath, targetpath))
|
||||
return self.client.makeLink(linkpath, targetpath).addCallback(_ignore)
|
||||
|
||||
def cmd_LS(self, rest):
|
||||
# possible lines:
|
||||
# ls current directory
|
||||
# ls name_of_file that file
|
||||
# ls name_of_directory that directory
|
||||
# ls some_glob_string current directory, globbed for that string
|
||||
options = []
|
||||
rest = rest.split()
|
||||
while rest and rest[0] and rest[0][0] == '-':
|
||||
opts = rest.pop(0)[1:]
|
||||
for o in opts:
|
||||
if o == 'l':
|
||||
options.append('verbose')
|
||||
elif o == 'a':
|
||||
options.append('all')
|
||||
rest = ' '.join(rest)
|
||||
path, rest = self._getFilename(rest)
|
||||
if not path:
|
||||
fullPath = self.currentDirectory + '/'
|
||||
else:
|
||||
fullPath = os.path.join(self.currentDirectory, path)
|
||||
d = self._remoteGlob(fullPath)
|
||||
d.addCallback(self._cbDisplayFiles, options)
|
||||
return d
|
||||
|
||||
def _cbDisplayFiles(self, files, options):
|
||||
files.sort()
|
||||
if 'all' not in options:
|
||||
files = [f for f in files if not f[0].startswith('.')]
|
||||
if 'verbose' in options:
|
||||
lines = [f[1] for f in files]
|
||||
else:
|
||||
lines = [f[0] for f in files]
|
||||
if not lines:
|
||||
return None
|
||||
else:
|
||||
return '\n'.join(lines)
|
||||
|
||||
def cmd_MKDIR(self, path):
|
||||
path, rest = self._getFilename(path)
|
||||
path = os.path.join(self.currentDirectory, path)
|
||||
return self.client.makeDirectory(path, {}).addCallback(_ignore)
|
||||
|
||||
def cmd_RMDIR(self, path):
|
||||
path, rest = self._getFilename(path)
|
||||
path = os.path.join(self.currentDirectory, path)
|
||||
return self.client.removeDirectory(path).addCallback(_ignore)
|
||||
|
||||
def cmd_LMKDIR(self, path):
|
||||
os.system("mkdir %s" % path)
|
||||
|
||||
def cmd_RM(self, path):
|
||||
path, rest = self._getFilename(path)
|
||||
path = os.path.join(self.currentDirectory, path)
|
||||
return self.client.removeFile(path).addCallback(_ignore)
|
||||
|
||||
def cmd_LLS(self, rest):
|
||||
os.system("ls %s" % rest)
|
||||
|
||||
def cmd_RENAME(self, rest):
|
||||
oldpath, rest = self._getFilename(rest)
|
||||
newpath, rest = self._getFilename(rest)
|
||||
oldpath, newpath = map (
|
||||
lambda x: os.path.join(self.currentDirectory, x),
|
||||
(oldpath, newpath))
|
||||
return self.client.renameFile(oldpath, newpath).addCallback(_ignore)
|
||||
|
||||
def cmd_EXIT(self, ignored):
|
||||
self.client.transport.loseConnection()
|
||||
|
||||
cmd_QUIT = cmd_EXIT
|
||||
|
||||
def cmd_VERSION(self, ignored):
|
||||
return "SFTP version %i" % self.client.version
|
||||
|
||||
def cmd_HELP(self, ignored):
|
||||
return """Available commands:
|
||||
cd path Change remote directory to 'path'.
|
||||
chgrp gid path Change gid of 'path' to 'gid'.
|
||||
chmod mode path Change mode of 'path' to 'mode'.
|
||||
chown uid path Change uid of 'path' to 'uid'.
|
||||
exit Disconnect from the server.
|
||||
get remote-path [local-path] Get remote file.
|
||||
help Get a list of available commands.
|
||||
lcd path Change local directory to 'path'.
|
||||
lls [ls-options] [path] Display local directory listing.
|
||||
lmkdir path Create local directory.
|
||||
ln linkpath targetpath Symlink remote file.
|
||||
lpwd Print the local working directory.
|
||||
ls [-l] [path] Display remote directory listing.
|
||||
mkdir path Create remote directory.
|
||||
progress Toggle progress bar.
|
||||
put local-path [remote-path] Put local file.
|
||||
pwd Print the remote working directory.
|
||||
quit Disconnect from the server.
|
||||
rename oldpath newpath Rename remote file.
|
||||
rmdir path Remove remote directory.
|
||||
rm path Remove remote file.
|
||||
version Print the SFTP version.
|
||||
? Synonym for 'help'.
|
||||
"""
|
||||
|
||||
def cmd_PWD(self, ignored):
|
||||
return self.currentDirectory
|
||||
|
||||
def cmd_LPWD(self, ignored):
|
||||
return os.getcwd()
|
||||
|
||||
def cmd_PROGRESS(self, ignored):
|
||||
self.useProgressBar = not self.useProgressBar
|
||||
return "%ssing progess bar." % (self.useProgressBar and "U" or "Not u")
|
||||
|
||||
def cmd_EXEC(self, rest):
|
||||
"""
|
||||
Run C{rest} using the user's shell (or /bin/sh if they do not have
|
||||
one).
|
||||
"""
|
||||
shell = self._pwd.getpwnam(getpass.getuser())[6]
|
||||
if not shell:
|
||||
shell = '/bin/sh'
|
||||
if rest:
|
||||
cmds = ['-c', rest]
|
||||
return utils.getProcessOutput(shell, cmds, errortoo=1)
|
||||
else:
|
||||
os.system(shell)
|
||||
|
||||
# accessory functions
|
||||
|
||||
def _remoteGlob(self, fullPath):
|
||||
log.msg('looking up %s' % fullPath)
|
||||
head, tail = os.path.split(fullPath)
|
||||
if '*' in tail or '?' in tail:
|
||||
glob = 1
|
||||
else:
|
||||
glob = 0
|
||||
if tail and not glob: # could be file or directory
|
||||
# try directory first
|
||||
d = self.client.openDirectory(fullPath)
|
||||
d.addCallback(self._cbOpenList, '')
|
||||
d.addErrback(self._ebNotADirectory, head, tail)
|
||||
else:
|
||||
d = self.client.openDirectory(head)
|
||||
d.addCallback(self._cbOpenList, tail)
|
||||
return d
|
||||
|
||||
def _cbOpenList(self, directory, glob):
|
||||
files = []
|
||||
d = directory.read()
|
||||
d.addBoth(self._cbReadFile, files, directory, glob)
|
||||
return d
|
||||
|
||||
def _ebNotADirectory(self, reason, path, glob):
|
||||
d = self.client.openDirectory(path)
|
||||
d.addCallback(self._cbOpenList, glob)
|
||||
return d
|
||||
|
||||
def _cbReadFile(self, files, l, directory, glob):
|
||||
if not isinstance(files, failure.Failure):
|
||||
if glob:
|
||||
l.extend([f for f in files if fnmatch.fnmatch(f[0], glob)])
|
||||
else:
|
||||
l.extend(files)
|
||||
d = directory.read()
|
||||
d.addBoth(self._cbReadFile, l, directory, glob)
|
||||
return d
|
||||
else:
|
||||
reason = files
|
||||
reason.trap(EOFError)
|
||||
directory.close()
|
||||
return l
|
||||
|
||||
def _abbrevSize(self, size):
|
||||
# from http://mail.python.org/pipermail/python-list/1999-December/018395.html
|
||||
_abbrevs = [
|
||||
(1<<50L, 'PB'),
|
||||
(1<<40L, 'TB'),
|
||||
(1<<30L, 'GB'),
|
||||
(1<<20L, 'MB'),
|
||||
(1<<10L, 'kB'),
|
||||
(1, 'B')
|
||||
]
|
||||
|
||||
for factor, suffix in _abbrevs:
|
||||
if size > factor:
|
||||
break
|
||||
return '%.1f' % (size/factor) + suffix
|
||||
|
||||
def _abbrevTime(self, t):
|
||||
if t > 3600: # 1 hour
|
||||
hours = int(t / 3600)
|
||||
t -= (3600 * hours)
|
||||
mins = int(t / 60)
|
||||
t -= (60 * mins)
|
||||
return "%i:%02i:%02i" % (hours, mins, t)
|
||||
else:
|
||||
mins = int(t/60)
|
||||
t -= (60 * mins)
|
||||
return "%02i:%02i" % (mins, t)
|
||||
|
||||
|
||||
def _printProgressBar(self, f, startTime):
|
||||
"""
|
||||
Update a console progress bar on this L{StdioClient}'s transport, based
|
||||
on the difference between the start time of the operation and the
|
||||
current time according to the reactor, and appropriate to the size of
|
||||
the console window.
|
||||
|
||||
@param f: a wrapper around the file which is being written or read
|
||||
@type f: L{FileWrapper}
|
||||
|
||||
@param startTime: The time at which the operation being tracked began.
|
||||
@type startTime: C{float}
|
||||
"""
|
||||
diff = self.reactor.seconds() - startTime
|
||||
total = f.total
|
||||
try:
|
||||
winSize = struct.unpack('4H',
|
||||
fcntl.ioctl(0, tty.TIOCGWINSZ, '12345679'))
|
||||
except IOError:
|
||||
winSize = [None, 80]
|
||||
if diff == 0.0:
|
||||
speed = 0.0
|
||||
else:
|
||||
speed = total / diff
|
||||
if speed:
|
||||
timeLeft = (f.size - total) / speed
|
||||
else:
|
||||
timeLeft = 0
|
||||
front = f.name
|
||||
back = '%3i%% %s %sps %s ' % ((total / f.size) * 100,
|
||||
self._abbrevSize(total),
|
||||
self._abbrevSize(speed),
|
||||
self._abbrevTime(timeLeft))
|
||||
spaces = (winSize[1] - (len(front) + len(back) + 1)) * ' '
|
||||
self.transport.write('\r%s%s%s' % (front, spaces, back))
|
||||
|
||||
|
||||
def _getFilename(self, line):
|
||||
line.lstrip()
|
||||
if not line:
|
||||
return None, ''
|
||||
if line[0] in '\'"':
|
||||
ret = []
|
||||
line = list(line)
|
||||
try:
|
||||
for i in range(1,len(line)):
|
||||
c = line[i]
|
||||
if c == line[0]:
|
||||
return ''.join(ret), ''.join(line[i+1:]).lstrip()
|
||||
elif c == '\\': # quoted character
|
||||
del line[i]
|
||||
if line[i] not in '\'"\\':
|
||||
raise IndexError, "bad quote: \\%s" % line[i]
|
||||
ret.append(line[i])
|
||||
else:
|
||||
ret.append(line[i])
|
||||
except IndexError:
|
||||
raise IndexError, "unterminated quote"
|
||||
ret = line.split(None, 1)
|
||||
if len(ret) == 1:
|
||||
return ret[0], ''
|
||||
else:
|
||||
return ret
|
||||
|
||||
StdioClient.__dict__['cmd_?'] = StdioClient.cmd_HELP
|
||||
|
||||
class SSHConnection(connection.SSHConnection):
|
||||
def serviceStarted(self):
|
||||
self.openChannel(SSHSession())
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
|
||||
name = 'session'
|
||||
|
||||
def channelOpen(self, foo):
|
||||
log.msg('session %s open' % self.id)
|
||||
if self.conn.options['subsystem'].startswith('/'):
|
||||
request = 'exec'
|
||||
else:
|
||||
request = 'subsystem'
|
||||
d = self.conn.sendRequest(self, request, \
|
||||
common.NS(self.conn.options['subsystem']), wantReply=1)
|
||||
d.addCallback(self._cbSubsystem)
|
||||
d.addErrback(_ebExit)
|
||||
|
||||
def _cbSubsystem(self, result):
|
||||
self.client = filetransfer.FileTransferClient()
|
||||
self.client.makeConnection(self)
|
||||
self.dataReceived = self.client.dataReceived
|
||||
f = None
|
||||
if self.conn.options['batchfile']:
|
||||
fn = self.conn.options['batchfile']
|
||||
if fn != '-':
|
||||
f = file(fn)
|
||||
self.stdio = stdio.StandardIO(StdioClient(self.client, f))
|
||||
|
||||
def extReceived(self, t, data):
|
||||
if t==connection.EXTENDED_DATA_STDERR:
|
||||
log.msg('got %s stderr data' % len(data))
|
||||
sys.stderr.write(data)
|
||||
sys.stderr.flush()
|
||||
|
||||
def eofReceived(self):
|
||||
log.msg('got eof')
|
||||
self.stdio.loseWriteConnection()
|
||||
|
||||
def closeReceived(self):
|
||||
log.msg('remote side closed %s' % self)
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def closed(self):
|
||||
try:
|
||||
reactor.stop()
|
||||
except:
|
||||
pass
|
||||
|
||||
def stopWriting(self):
|
||||
self.stdio.pauseProducing()
|
||||
|
||||
def startWriting(self):
|
||||
self.stdio.resumeProducing()
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
@ -0,0 +1,223 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation module for the `ckeygen` command.
|
||||
"""
|
||||
|
||||
import sys, os, getpass, socket
|
||||
if getpass.getpass == getpass.unix_getpass:
|
||||
try:
|
||||
import termios # hack around broken termios
|
||||
termios.tcgetattr, termios.tcsetattr
|
||||
except (ImportError, AttributeError):
|
||||
sys.modules['termios'] = None
|
||||
reload(getpass)
|
||||
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.python import failure, filepath, log, usage, randbytes
|
||||
|
||||
|
||||
|
||||
class GeneralOptions(usage.Options):
|
||||
synopsis = """Usage: ckeygen [options]
|
||||
"""
|
||||
|
||||
longdesc = "ckeygen manipulates public/private keys in various ways."
|
||||
|
||||
optParameters = [['bits', 'b', 1024, 'Number of bits in the key to create.'],
|
||||
['filename', 'f', None, 'Filename of the key file.'],
|
||||
['type', 't', None, 'Specify type of key to create.'],
|
||||
['comment', 'C', None, 'Provide new comment.'],
|
||||
['newpass', 'N', None, 'Provide new passphrase.'],
|
||||
['pass', 'P', None, 'Provide old passphrase.']]
|
||||
|
||||
optFlags = [['fingerprint', 'l', 'Show fingerprint of key file.'],
|
||||
['changepass', 'p', 'Change passphrase of private key file.'],
|
||||
['quiet', 'q', 'Quiet.'],
|
||||
['no-passphrase', None, "Create the key with no passphrase."],
|
||||
['showpub', 'y', 'Read private key file and print public key.']]
|
||||
|
||||
compData = usage.Completions(
|
||||
optActions={"type": usage.CompleteList(["rsa", "dsa"])})
|
||||
|
||||
|
||||
|
||||
def run():
|
||||
options = GeneralOptions()
|
||||
try:
|
||||
options.parseOptions(sys.argv[1:])
|
||||
except usage.UsageError, u:
|
||||
print 'ERROR: %s' % u
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
log.discardLogs()
|
||||
log.deferr = handleError # HACK
|
||||
if options['type']:
|
||||
if options['type'] == 'rsa':
|
||||
generateRSAkey(options)
|
||||
elif options['type'] == 'dsa':
|
||||
generateDSAkey(options)
|
||||
else:
|
||||
sys.exit('Key type was %s, must be one of: rsa, dsa' % options['type'])
|
||||
elif options['fingerprint']:
|
||||
printFingerprint(options)
|
||||
elif options['changepass']:
|
||||
changePassPhrase(options)
|
||||
elif options['showpub']:
|
||||
displayPublicKey(options)
|
||||
else:
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
def handleError():
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
log.err(failure.Failure())
|
||||
raise
|
||||
|
||||
|
||||
|
||||
def generateRSAkey(options):
|
||||
from Crypto.PublicKey import RSA
|
||||
print 'Generating public/private rsa key pair.'
|
||||
key = RSA.generate(int(options['bits']), randbytes.secureRandom)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
|
||||
def generateDSAkey(options):
|
||||
from Crypto.PublicKey import DSA
|
||||
print 'Generating public/private dsa key pair.'
|
||||
key = DSA.generate(int(options['bits']), randbytes.secureRandom)
|
||||
_saveKey(key, options)
|
||||
|
||||
|
||||
|
||||
def printFingerprint(options):
|
||||
if not options['filename']:
|
||||
filename = os.path.expanduser('~/.ssh/id_rsa')
|
||||
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
|
||||
if os.path.exists(options['filename']+'.pub'):
|
||||
options['filename'] += '.pub'
|
||||
try:
|
||||
key = keys.Key.fromFile(options['filename'])
|
||||
obj = key.keyObject
|
||||
print '%s %s %s' % (
|
||||
obj.size() + 1,
|
||||
key.fingerprint(),
|
||||
os.path.basename(options['filename']))
|
||||
except:
|
||||
sys.exit('bad key')
|
||||
|
||||
|
||||
|
||||
def changePassPhrase(options):
|
||||
if not options['filename']:
|
||||
filename = os.path.expanduser('~/.ssh/id_rsa')
|
||||
options['filename'] = raw_input(
|
||||
'Enter file in which the key is (%s): ' % filename)
|
||||
try:
|
||||
key = keys.Key.fromFile(options['filename']).keyObject
|
||||
except keys.EncryptedKeyError as e:
|
||||
# Raised if password not supplied for an encrypted key
|
||||
if not options.get('pass'):
|
||||
options['pass'] = getpass.getpass('Enter old passphrase: ')
|
||||
try:
|
||||
key = keys.Key.fromFile(
|
||||
options['filename'], passphrase=options['pass']).keyObject
|
||||
except keys.BadKeyError:
|
||||
sys.exit('Could not change passphrase: old passphrase error')
|
||||
except keys.EncryptedKeyError as e:
|
||||
sys.exit('Could not change passphrase: %s' % (e,))
|
||||
except keys.BadKeyError as e:
|
||||
sys.exit('Could not change passphrase: %s' % (e,))
|
||||
|
||||
if not options.get('newpass'):
|
||||
while 1:
|
||||
p1 = getpass.getpass(
|
||||
'Enter new passphrase (empty for no passphrase): ')
|
||||
p2 = getpass.getpass('Enter same passphrase again: ')
|
||||
if p1 == p2:
|
||||
break
|
||||
print 'Passphrases do not match. Try again.'
|
||||
options['newpass'] = p1
|
||||
|
||||
try:
|
||||
newkeydata = keys.Key(key).toString('openssh',
|
||||
extra=options['newpass'])
|
||||
except Exception as e:
|
||||
sys.exit('Could not change passphrase: %s' % (e,))
|
||||
|
||||
try:
|
||||
keys.Key.fromString(newkeydata, passphrase=options['newpass'])
|
||||
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
|
||||
sys.exit('Could not change passphrase: %s' % (e,))
|
||||
|
||||
fd = open(options['filename'], 'w')
|
||||
fd.write(newkeydata)
|
||||
fd.close()
|
||||
|
||||
print 'Your identification has been saved with the new passphrase.'
|
||||
|
||||
|
||||
|
||||
def displayPublicKey(options):
|
||||
if not options['filename']:
|
||||
filename = os.path.expanduser('~/.ssh/id_rsa')
|
||||
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
|
||||
try:
|
||||
key = keys.Key.fromFile(options['filename']).keyObject
|
||||
except keys.EncryptedKeyError:
|
||||
if not options.get('pass'):
|
||||
options['pass'] = getpass.getpass('Enter passphrase: ')
|
||||
key = keys.Key.fromFile(
|
||||
options['filename'], passphrase = options['pass']).keyObject
|
||||
print keys.Key(key).public().toString('openssh')
|
||||
|
||||
|
||||
|
||||
def _saveKey(key, options):
|
||||
if not options['filename']:
|
||||
kind = keys.objectType(key)
|
||||
kind = {'ssh-rsa':'rsa','ssh-dss':'dsa'}[kind]
|
||||
filename = os.path.expanduser('~/.ssh/id_%s'%kind)
|
||||
options['filename'] = raw_input('Enter file in which to save the key (%s): '%filename).strip() or filename
|
||||
if os.path.exists(options['filename']):
|
||||
print '%s already exists.' % options['filename']
|
||||
yn = raw_input('Overwrite (y/n)? ')
|
||||
if yn[0].lower() != 'y':
|
||||
sys.exit()
|
||||
if options.get('no-passphrase'):
|
||||
options['pass'] = b''
|
||||
elif not options['pass']:
|
||||
while 1:
|
||||
p1 = getpass.getpass('Enter passphrase (empty for no passphrase): ')
|
||||
p2 = getpass.getpass('Enter same passphrase again: ')
|
||||
if p1 == p2:
|
||||
break
|
||||
print 'Passphrases do not match. Try again.'
|
||||
options['pass'] = p1
|
||||
|
||||
keyObj = keys.Key(key)
|
||||
comment = '%s@%s' % (getpass.getuser(), socket.gethostname())
|
||||
|
||||
filepath.FilePath(options['filename']).setContent(
|
||||
keyObj.toString('openssh', options['pass']))
|
||||
os.chmod(options['filename'], 33152)
|
||||
|
||||
filepath.FilePath(options['filename'] + '.pub').setContent(
|
||||
keyObj.public().toString('openssh', comment))
|
||||
|
||||
print 'Your identification has been saved in %s' % options['filename']
|
||||
print 'Your public key has been saved in %s.pub' % options['filename']
|
||||
print 'The key fingerprint is:'
|
||||
print keyObj.fingerprint()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
@ -0,0 +1,508 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
#
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
|
||||
|
||||
#""" Implementation module for the `conch` command.
|
||||
#"""
|
||||
from twisted.conch.client import connect, default, options
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.ssh import connection, common
|
||||
from twisted.conch.ssh import session, forwarding, channel
|
||||
from twisted.internet import reactor, stdio, task
|
||||
from twisted.python import log, usage
|
||||
|
||||
import os, sys, getpass, struct, tty, fcntl, signal
|
||||
|
||||
class ClientOptions(options.ConchOptions):
|
||||
|
||||
synopsis = """Usage: conch [options] host [command]
|
||||
"""
|
||||
longdesc = ("conch is a SSHv2 client that allows logging into a remote "
|
||||
"machine and executing commands.")
|
||||
|
||||
optParameters = [['escape', 'e', '~'],
|
||||
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
|
||||
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
|
||||
]
|
||||
|
||||
optFlags = [['null', 'n', 'Redirect input from /dev/null.'],
|
||||
['fork', 'f', 'Fork to background after authentication.'],
|
||||
['tty', 't', 'Tty; allocate a tty even if command is given.'],
|
||||
['notty', 'T', 'Do not allocate a tty.'],
|
||||
['noshell', 'N', 'Do not execute a shell or command.'],
|
||||
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
|
||||
]
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("tty", "notty")],
|
||||
optActions={
|
||||
"localforward": usage.Completer(descr="listen-port:host:port"),
|
||||
"remoteforward": usage.Completer(descr="listen-port:host:port")},
|
||||
extraActions=[usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr="argument", repeat=True)]
|
||||
)
|
||||
|
||||
localForwards = []
|
||||
remoteForwards = []
|
||||
|
||||
def opt_escape(self, esc):
|
||||
"Set escape character; ``none'' = disable"
|
||||
if esc == 'none':
|
||||
self['escape'] = None
|
||||
elif esc[0] == '^' and len(esc) == 2:
|
||||
self['escape'] = chr(ord(esc[1])-64)
|
||||
elif len(esc) == 1:
|
||||
self['escape'] = esc
|
||||
else:
|
||||
sys.exit("Bad escape character '%s'." % esc)
|
||||
|
||||
def opt_localforward(self, f):
|
||||
"Forward local port to remote address (lport:host:port)"
|
||||
localPort, remoteHost, remotePort = f.split(':') # doesn't do v6 yet
|
||||
localPort = int(localPort)
|
||||
remotePort = int(remotePort)
|
||||
self.localForwards.append((localPort, (remoteHost, remotePort)))
|
||||
|
||||
def opt_remoteforward(self, f):
|
||||
"""Forward remote port to local address (rport:host:port)"""
|
||||
remotePort, connHost, connPort = f.split(':') # doesn't do v6 yet
|
||||
remotePort = int(remotePort)
|
||||
connPort = int(connPort)
|
||||
self.remoteForwards.append((remotePort, (connHost, connPort)))
|
||||
|
||||
def parseArgs(self, host, *command):
|
||||
self['host'] = host
|
||||
self['command'] = ' '.join(command)
|
||||
|
||||
# Rest of code in "run"
|
||||
options = None
|
||||
conn = None
|
||||
exitStatus = 0
|
||||
old = None
|
||||
_inRawMode = 0
|
||||
_savedRawMode = None
|
||||
|
||||
def run():
|
||||
global options, old
|
||||
args = sys.argv[1:]
|
||||
if '-l' in args: # cvs is an idiot
|
||||
i = args.index('-l')
|
||||
args = args[i:i+2]+args
|
||||
del args[i+2:i+4]
|
||||
for arg in args[:]:
|
||||
try:
|
||||
i = args.index(arg)
|
||||
if arg[:2] == '-o' and args[i+1][0]!='-':
|
||||
args[i:i+2] = [] # suck on it scp
|
||||
except ValueError:
|
||||
pass
|
||||
options = ClientOptions()
|
||||
try:
|
||||
options.parseOptions(args)
|
||||
except usage.UsageError, u:
|
||||
print 'ERROR: %s' % u
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
if options['log']:
|
||||
if options['logfile']:
|
||||
if options['logfile'] == '-':
|
||||
f = sys.stdout
|
||||
else:
|
||||
f = file(options['logfile'], 'a+')
|
||||
else:
|
||||
f = sys.stderr
|
||||
realout = sys.stdout
|
||||
log.startLogging(f)
|
||||
sys.stdout = realout
|
||||
else:
|
||||
log.discardLogs()
|
||||
doConnect()
|
||||
fd = sys.stdin.fileno()
|
||||
try:
|
||||
old = tty.tcgetattr(fd)
|
||||
except:
|
||||
old = None
|
||||
try:
|
||||
oldUSR1 = signal.signal(signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect))
|
||||
except:
|
||||
oldUSR1 = None
|
||||
try:
|
||||
reactor.run()
|
||||
finally:
|
||||
if old:
|
||||
tty.tcsetattr(fd, tty.TCSANOW, old)
|
||||
if oldUSR1:
|
||||
signal.signal(signal.SIGUSR1, oldUSR1)
|
||||
if (options['command'] and options['tty']) or not options['notty']:
|
||||
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
|
||||
if sys.stdout.isatty() and not options['command']:
|
||||
print 'Connection to %s closed.' % options['host']
|
||||
sys.exit(exitStatus)
|
||||
|
||||
def handleError():
|
||||
from twisted.python import failure
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
reactor.callLater(0.01, _stopReactor)
|
||||
log.err(failure.Failure())
|
||||
raise
|
||||
|
||||
def _stopReactor():
|
||||
try:
|
||||
reactor.stop()
|
||||
except: pass
|
||||
|
||||
def doConnect():
|
||||
# log.deferr = handleError # HACK
|
||||
if '@' in options['host']:
|
||||
options['user'], options['host'] = options['host'].split('@',1)
|
||||
if not options.identitys:
|
||||
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
|
||||
host = options['host']
|
||||
if not options['user']:
|
||||
options['user'] = getpass.getuser()
|
||||
if not options['port']:
|
||||
options['port'] = 22
|
||||
else:
|
||||
options['port'] = int(options['port'])
|
||||
host = options['host']
|
||||
port = options['port']
|
||||
vhk = default.verifyHostKey
|
||||
uao = default.SSHUserAuthClient(options['user'], options, SSHConnection())
|
||||
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
|
||||
|
||||
def _ebExit(f):
|
||||
global exitStatus
|
||||
exitStatus = "conch: exiting with error %s" % f
|
||||
reactor.callLater(0.1, _stopReactor)
|
||||
|
||||
def onConnect():
|
||||
# if keyAgent and options['agent']:
|
||||
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
|
||||
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
|
||||
if hasattr(conn.transport, 'sendIgnore'):
|
||||
_KeepAlive(conn)
|
||||
if options.localForwards:
|
||||
for localPort, hostport in options.localForwards:
|
||||
s = reactor.listenTCP(localPort,
|
||||
forwarding.SSHListenForwardingFactory(conn,
|
||||
hostport,
|
||||
SSHListenClientForwardingChannel))
|
||||
conn.localForwards.append(s)
|
||||
if options.remoteForwards:
|
||||
for remotePort, hostport in options.remoteForwards:
|
||||
log.msg('asking for remote forwarding for %s:%s' %
|
||||
(remotePort, hostport))
|
||||
conn.requestRemoteForwarding(remotePort, hostport)
|
||||
reactor.addSystemEventTrigger('before', 'shutdown', beforeShutdown)
|
||||
if not options['noshell'] or options['agent']:
|
||||
conn.openChannel(SSHSession())
|
||||
if options['fork']:
|
||||
if os.fork():
|
||||
os._exit(0)
|
||||
os.setsid()
|
||||
for i in range(3):
|
||||
try:
|
||||
os.close(i)
|
||||
except OSError, e:
|
||||
import errno
|
||||
if e.errno != errno.EBADF:
|
||||
raise
|
||||
|
||||
def reConnect():
|
||||
beforeShutdown()
|
||||
conn.transport.transport.loseConnection()
|
||||
|
||||
def beforeShutdown():
|
||||
remoteForwards = options.remoteForwards
|
||||
for remotePort, hostport in remoteForwards:
|
||||
log.msg('cancelling %s:%s' % (remotePort, hostport))
|
||||
conn.cancelRemoteForwarding(remotePort)
|
||||
|
||||
def stopConnection():
|
||||
if not options['reconnect']:
|
||||
reactor.callLater(0.1, _stopReactor)
|
||||
|
||||
class _KeepAlive:
|
||||
|
||||
def __init__(self, conn):
|
||||
self.conn = conn
|
||||
self.globalTimeout = None
|
||||
self.lc = task.LoopingCall(self.sendGlobal)
|
||||
self.lc.start(300)
|
||||
|
||||
def sendGlobal(self):
|
||||
d = self.conn.sendGlobalRequest("conch-keep-alive@twistedmatrix.com",
|
||||
"", wantReply = 1)
|
||||
d.addBoth(self._cbGlobal)
|
||||
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
|
||||
|
||||
def _cbGlobal(self, res):
|
||||
if self.globalTimeout:
|
||||
self.globalTimeout.cancel()
|
||||
self.globalTimeout = None
|
||||
|
||||
def _ebGlobal(self):
|
||||
if self.globalTimeout:
|
||||
self.globalTimeout = None
|
||||
self.conn.transport.loseConnection()
|
||||
|
||||
class SSHConnection(connection.SSHConnection):
|
||||
def serviceStarted(self):
|
||||
global conn
|
||||
conn = self
|
||||
self.localForwards = []
|
||||
self.remoteForwards = {}
|
||||
if not isinstance(self, connection.SSHConnection):
|
||||
# make these fall through
|
||||
del self.__class__.requestRemoteForwarding
|
||||
del self.__class__.cancelRemoteForwarding
|
||||
onConnect()
|
||||
|
||||
def serviceStopped(self):
|
||||
lf = self.localForwards
|
||||
self.localForwards = []
|
||||
for s in lf:
|
||||
s.loseConnection()
|
||||
stopConnection()
|
||||
|
||||
def requestRemoteForwarding(self, remotePort, hostport):
|
||||
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
|
||||
d = self.sendGlobalRequest('tcpip-forward', data,
|
||||
wantReply=1)
|
||||
log.msg('requesting remote forwarding %s:%s' %(remotePort, hostport))
|
||||
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
|
||||
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
|
||||
|
||||
def _cbRemoteForwarding(self, result, remotePort, hostport):
|
||||
log.msg('accepted remote forwarding %s:%s' % (remotePort, hostport))
|
||||
self.remoteForwards[remotePort] = hostport
|
||||
log.msg(repr(self.remoteForwards))
|
||||
|
||||
def _ebRemoteForwarding(self, f, remotePort, hostport):
|
||||
log.msg('remote forwarding %s:%s failed' % (remotePort, hostport))
|
||||
log.msg(f)
|
||||
|
||||
def cancelRemoteForwarding(self, remotePort):
|
||||
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
|
||||
self.sendGlobalRequest('cancel-tcpip-forward', data)
|
||||
log.msg('cancelling remote forwarding %s' % remotePort)
|
||||
try:
|
||||
del self.remoteForwards[remotePort]
|
||||
except:
|
||||
pass
|
||||
log.msg(repr(self.remoteForwards))
|
||||
|
||||
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
|
||||
log.msg('%s %s' % ('FTCP', repr(data)))
|
||||
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
|
||||
log.msg(self.remoteForwards)
|
||||
log.msg(remoteHP)
|
||||
if self.remoteForwards.has_key(remoteHP[1]):
|
||||
connectHP = self.remoteForwards[remoteHP[1]]
|
||||
log.msg('connect forwarding %s' % (connectHP,))
|
||||
return SSHConnectForwardingChannel(connectHP,
|
||||
remoteWindow = windowSize,
|
||||
remoteMaxPacket = maxPacket,
|
||||
conn = self)
|
||||
else:
|
||||
raise ConchError(connection.OPEN_CONNECT_FAILED, "don't know about that port")
|
||||
|
||||
# def channel_auth_agent_openssh_com(self, windowSize, maxPacket, data):
|
||||
# if options['agent'] and keyAgent:
|
||||
# return agent.SSHAgentForwardingChannel(remoteWindow = windowSize,
|
||||
# remoteMaxPacket = maxPacket,
|
||||
# conn = self)
|
||||
# else:
|
||||
# return connection.OPEN_CONNECT_FAILED, "don't have an agent"
|
||||
|
||||
def channelClosed(self, channel):
|
||||
log.msg('connection closing %s' % channel)
|
||||
log.msg(self.channels)
|
||||
if len(self.channels) == 1: # just us left
|
||||
log.msg('stopping connection')
|
||||
stopConnection()
|
||||
else:
|
||||
# because of the unix thing
|
||||
self.__class__.__bases__[0].channelClosed(self, channel)
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
|
||||
name = 'session'
|
||||
|
||||
def channelOpen(self, foo):
|
||||
log.msg('session %s open' % self.id)
|
||||
if options['agent']:
|
||||
d = self.conn.sendRequest(self, 'auth-agent-req@openssh.com', '', wantReply=1)
|
||||
d.addBoth(lambda x:log.msg(x))
|
||||
if options['noshell']: return
|
||||
if (options['command'] and options['tty']) or not options['notty']:
|
||||
_enterRawMode()
|
||||
c = session.SSHSessionClient()
|
||||
if options['escape'] and not options['notty']:
|
||||
self.escapeMode = 1
|
||||
c.dataReceived = self.handleInput
|
||||
else:
|
||||
c.dataReceived = self.write
|
||||
c.connectionLost = lambda x=None,s=self:s.sendEOF()
|
||||
self.stdio = stdio.StandardIO(c)
|
||||
fd = 0
|
||||
if options['subsystem']:
|
||||
self.conn.sendRequest(self, 'subsystem', \
|
||||
common.NS(options['command']))
|
||||
elif options['command']:
|
||||
if options['tty']:
|
||||
term = os.environ['TERM']
|
||||
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, '')
|
||||
self.conn.sendRequest(self, 'pty-req', ptyReqData)
|
||||
signal.signal(signal.SIGWINCH, self._windowResized)
|
||||
self.conn.sendRequest(self, 'exec', \
|
||||
common.NS(options['command']))
|
||||
else:
|
||||
if not options['notty']:
|
||||
term = os.environ['TERM']
|
||||
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, '')
|
||||
self.conn.sendRequest(self, 'pty-req', ptyReqData)
|
||||
signal.signal(signal.SIGWINCH, self._windowResized)
|
||||
self.conn.sendRequest(self, 'shell', '')
|
||||
#if hasattr(conn.transport, 'transport'):
|
||||
# conn.transport.transport.setTcpNoDelay(1)
|
||||
|
||||
def handleInput(self, char):
|
||||
#log.msg('handling %s' % repr(char))
|
||||
if char in ('\n', '\r'):
|
||||
self.escapeMode = 1
|
||||
self.write(char)
|
||||
elif self.escapeMode == 1 and char == options['escape']:
|
||||
self.escapeMode = 2
|
||||
elif self.escapeMode == 2:
|
||||
self.escapeMode = 1 # so we can chain escapes together
|
||||
if char == '.': # disconnect
|
||||
log.msg('disconnecting from escape')
|
||||
stopConnection()
|
||||
return
|
||||
elif char == '\x1a': # ^Z, suspend
|
||||
def _():
|
||||
_leaveRawMode()
|
||||
sys.stdout.flush()
|
||||
sys.stdin.flush()
|
||||
os.kill(os.getpid(), signal.SIGTSTP)
|
||||
_enterRawMode()
|
||||
reactor.callLater(0, _)
|
||||
return
|
||||
elif char == 'R': # rekey connection
|
||||
log.msg('rekeying connection')
|
||||
self.conn.transport.sendKexInit()
|
||||
return
|
||||
elif char == '#': # display connections
|
||||
self.stdio.write('\r\nThe following connections are open:\r\n')
|
||||
channels = self.conn.channels.keys()
|
||||
channels.sort()
|
||||
for channelId in channels:
|
||||
self.stdio.write(' #%i %s\r\n' % (channelId, str(self.conn.channels[channelId])))
|
||||
return
|
||||
self.write('~' + char)
|
||||
else:
|
||||
self.escapeMode = 0
|
||||
self.write(char)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.stdio.write(data)
|
||||
|
||||
def extReceived(self, t, data):
|
||||
if t==connection.EXTENDED_DATA_STDERR:
|
||||
log.msg('got %s stderr data' % len(data))
|
||||
sys.stderr.write(data)
|
||||
|
||||
def eofReceived(self):
|
||||
log.msg('got eof')
|
||||
self.stdio.loseWriteConnection()
|
||||
|
||||
def closeReceived(self):
|
||||
log.msg('remote side closed %s' % self)
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def closed(self):
|
||||
global old
|
||||
log.msg('closed %s' % self)
|
||||
log.msg(repr(self.conn.channels))
|
||||
|
||||
def request_exit_status(self, data):
|
||||
global exitStatus
|
||||
exitStatus = int(struct.unpack('>L', data)[0])
|
||||
log.msg('exit status: %s' % exitStatus)
|
||||
|
||||
def sendEOF(self):
|
||||
self.conn.sendEOF(self)
|
||||
|
||||
def stopWriting(self):
|
||||
self.stdio.pauseProducing()
|
||||
|
||||
def startWriting(self):
|
||||
self.stdio.resumeProducing()
|
||||
|
||||
def _windowResized(self, *args):
|
||||
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = struct.unpack('4H', winsz)
|
||||
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
|
||||
self.conn.sendRequest(self, 'window-change', struct.pack('!4L', *newSize))
|
||||
|
||||
|
||||
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel): pass
|
||||
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel): pass
|
||||
|
||||
def _leaveRawMode():
|
||||
global _inRawMode
|
||||
if not _inRawMode:
|
||||
return
|
||||
fd = sys.stdin.fileno()
|
||||
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
|
||||
_inRawMode = 0
|
||||
|
||||
def _enterRawMode():
|
||||
global _inRawMode, _savedRawMode
|
||||
if _inRawMode:
|
||||
return
|
||||
fd = sys.stdin.fileno()
|
||||
try:
|
||||
old = tty.tcgetattr(fd)
|
||||
new = old[:]
|
||||
except:
|
||||
log.msg('not a typewriter!')
|
||||
else:
|
||||
# iflage
|
||||
new[0] = new[0] | tty.IGNPAR
|
||||
new[0] = new[0] & ~(tty.ISTRIP | tty.INLCR | tty.IGNCR | tty.ICRNL |
|
||||
tty.IXON | tty.IXANY | tty.IXOFF)
|
||||
if hasattr(tty, 'IUCLC'):
|
||||
new[0] = new[0] & ~tty.IUCLC
|
||||
|
||||
# lflag
|
||||
new[3] = new[3] & ~(tty.ISIG | tty.ICANON | tty.ECHO | tty.ECHO |
|
||||
tty.ECHOE | tty.ECHOK | tty.ECHONL)
|
||||
if hasattr(tty, 'IEXTEN'):
|
||||
new[3] = new[3] & ~tty.IEXTEN
|
||||
|
||||
#oflag
|
||||
new[1] = new[1] & ~tty.OPOST
|
||||
|
||||
new[6][tty.VMIN] = 1
|
||||
new[6][tty.VTIME] = 0
|
||||
|
||||
_savedRawMode = old
|
||||
tty.tcsetattr(fd, tty.TCSANOW, new)
|
||||
#tty.setraw(fd)
|
||||
_inRawMode = 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
run()
|
||||
|
@ -0,0 +1,573 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_scripts -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation module for the `tkconch` command.
|
||||
"""
|
||||
|
||||
import Tkinter, tkFileDialog, tkMessageBox
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ui import tkvt100
|
||||
from twisted.conch.ssh import transport, userauth, connection, common, keys
|
||||
from twisted.conch.ssh import session, forwarding, channel
|
||||
from twisted.conch.client.default import isInKnownHosts
|
||||
from twisted.internet import reactor, defer, protocol, tksupport
|
||||
from twisted.python import usage, log
|
||||
|
||||
import os, sys, getpass, struct, base64, signal
|
||||
|
||||
class TkConchMenu(Tkinter.Frame):
|
||||
def __init__(self, *args, **params):
|
||||
## Standard heading: initialization
|
||||
apply(Tkinter.Frame.__init__, (self,) + args, params)
|
||||
|
||||
self.master.title('TkConch')
|
||||
self.localRemoteVar = Tkinter.StringVar()
|
||||
self.localRemoteVar.set('local')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Hostname').grid(column=1, row=1, sticky='w')
|
||||
self.host = Tkinter.Entry(self)
|
||||
self.host.grid(column=2, columnspan=2, row=1, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Port').grid(column=1, row=2, sticky='w')
|
||||
self.port = Tkinter.Entry(self)
|
||||
self.port.grid(column=2, columnspan=2, row=2, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Username').grid(column=1, row=3, sticky='w')
|
||||
self.user = Tkinter.Entry(self)
|
||||
self.user.grid(column=2, columnspan=2, row=3, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Command').grid(column=1, row=4, sticky='w')
|
||||
self.command = Tkinter.Entry(self)
|
||||
self.command.grid(column=2, columnspan=2, row=4, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Identity').grid(column=1, row=5, sticky='w')
|
||||
self.identity = Tkinter.Entry(self)
|
||||
self.identity.grid(column=2, row=5, sticky='nesw')
|
||||
Tkinter.Button(self, command=self.getIdentityFile, text='Browse').grid(column=3, row=5, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, text='Port Forwarding').grid(column=1, row=6, sticky='w')
|
||||
self.forwards = Tkinter.Listbox(self, height=0, width=0)
|
||||
self.forwards.grid(column=2, columnspan=2, row=6, sticky='nesw')
|
||||
Tkinter.Button(self, text='Add', command=self.addForward).grid(column=1, row=7)
|
||||
Tkinter.Button(self, text='Remove', command=self.removeForward).grid(column=1, row=8)
|
||||
self.forwardPort = Tkinter.Entry(self)
|
||||
self.forwardPort.grid(column=2, row=7, sticky='nesw')
|
||||
Tkinter.Label(self, text='Port').grid(column=3, row=7, sticky='nesw')
|
||||
self.forwardHost = Tkinter.Entry(self)
|
||||
self.forwardHost.grid(column=2, row=8, sticky='nesw')
|
||||
Tkinter.Label(self, text='Host').grid(column=3, row=8, sticky='nesw')
|
||||
self.localForward = Tkinter.Radiobutton(self, text='Local', variable=self.localRemoteVar, value='local')
|
||||
self.localForward.grid(column=2, row=9)
|
||||
self.remoteForward = Tkinter.Radiobutton(self, text='Remote', variable=self.localRemoteVar, value='remote')
|
||||
self.remoteForward.grid(column=3, row=9)
|
||||
|
||||
Tkinter.Label(self, text='Advanced Options').grid(column=1, columnspan=3, row=10, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Cipher').grid(column=1, row=11, sticky='w')
|
||||
self.cipher = Tkinter.Entry(self, name='cipher')
|
||||
self.cipher.grid(column=2, columnspan=2, row=11, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='MAC').grid(column=1, row=12, sticky='w')
|
||||
self.mac = Tkinter.Entry(self, name='mac')
|
||||
self.mac.grid(column=2, columnspan=2, row=12, sticky='nesw')
|
||||
|
||||
Tkinter.Label(self, anchor='w', justify='left', text='Escape Char').grid(column=1, row=13, sticky='w')
|
||||
self.escape = Tkinter.Entry(self, name='escape')
|
||||
self.escape.grid(column=2, columnspan=2, row=13, sticky='nesw')
|
||||
Tkinter.Button(self, text='Connect!', command=self.doConnect).grid(column=1, columnspan=3, row=14, sticky='nesw')
|
||||
|
||||
# Resize behavior(s)
|
||||
self.grid_rowconfigure(6, weight=1, minsize=64)
|
||||
self.grid_columnconfigure(2, weight=1, minsize=2)
|
||||
|
||||
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
|
||||
|
||||
|
||||
def getIdentityFile(self):
|
||||
r = tkFileDialog.askopenfilename()
|
||||
if r:
|
||||
self.identity.delete(0, Tkinter.END)
|
||||
self.identity.insert(Tkinter.END, r)
|
||||
|
||||
def addForward(self):
|
||||
port = self.forwardPort.get()
|
||||
self.forwardPort.delete(0, Tkinter.END)
|
||||
host = self.forwardHost.get()
|
||||
self.forwardHost.delete(0, Tkinter.END)
|
||||
if self.localRemoteVar.get() == 'local':
|
||||
self.forwards.insert(Tkinter.END, 'L:%s:%s' % (port, host))
|
||||
else:
|
||||
self.forwards.insert(Tkinter.END, 'R:%s:%s' % (port, host))
|
||||
|
||||
def removeForward(self):
|
||||
cur = self.forwards.curselection()
|
||||
if cur:
|
||||
self.forwards.remove(cur[0])
|
||||
|
||||
def doConnect(self):
|
||||
finished = 1
|
||||
options['host'] = self.host.get()
|
||||
options['port'] = self.port.get()
|
||||
options['user'] = self.user.get()
|
||||
options['command'] = self.command.get()
|
||||
cipher = self.cipher.get()
|
||||
mac = self.mac.get()
|
||||
escape = self.escape.get()
|
||||
if cipher:
|
||||
if cipher in SSHClientTransport.supportedCiphers:
|
||||
SSHClientTransport.supportedCiphers = [cipher]
|
||||
else:
|
||||
tkMessageBox.showerror('TkConch', 'Bad cipher.')
|
||||
finished = 0
|
||||
|
||||
if mac:
|
||||
if mac in SSHClientTransport.supportedMACs:
|
||||
SSHClientTransport.supportedMACs = [mac]
|
||||
elif finished:
|
||||
tkMessageBox.showerror('TkConch', 'Bad MAC.')
|
||||
finished = 0
|
||||
|
||||
if escape:
|
||||
if escape == 'none':
|
||||
options['escape'] = None
|
||||
elif escape[0] == '^' and len(escape) == 2:
|
||||
options['escape'] = chr(ord(escape[1])-64)
|
||||
elif len(escape) == 1:
|
||||
options['escape'] = escape
|
||||
elif finished:
|
||||
tkMessageBox.showerror('TkConch', "Bad escape character '%s'." % escape)
|
||||
finished = 0
|
||||
|
||||
if self.identity.get():
|
||||
options.identitys.append(self.identity.get())
|
||||
|
||||
for line in self.forwards.get(0,Tkinter.END):
|
||||
if line[0]=='L':
|
||||
options.opt_localforward(line[2:])
|
||||
else:
|
||||
options.opt_remoteforward(line[2:])
|
||||
|
||||
if '@' in options['host']:
|
||||
options['user'], options['host'] = options['host'].split('@',1)
|
||||
|
||||
if (not options['host'] or not options['user']) and finished:
|
||||
tkMessageBox.showerror('TkConch', 'Missing host or username.')
|
||||
finished = 0
|
||||
if finished:
|
||||
self.master.quit()
|
||||
self.master.destroy()
|
||||
if options['log']:
|
||||
realout = sys.stdout
|
||||
log.startLogging(sys.stderr)
|
||||
sys.stdout = realout
|
||||
else:
|
||||
log.discardLogs()
|
||||
log.deferr = handleError # HACK
|
||||
if not options.identitys:
|
||||
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
|
||||
host = options['host']
|
||||
port = int(options['port'] or 22)
|
||||
log.msg((host,port))
|
||||
reactor.connectTCP(host, port, SSHClientFactory())
|
||||
frame.master.deiconify()
|
||||
frame.master.title('%s@%s - TkConch' % (options['user'], options['host']))
|
||||
else:
|
||||
self.focus()
|
||||
|
||||
class GeneralOptions(usage.Options):
|
||||
synopsis = """Usage: tkconch [options] host [command]
|
||||
"""
|
||||
|
||||
optParameters = [['user', 'l', None, 'Log in using this user name.'],
|
||||
['identity', 'i', '~/.ssh/identity', 'Identity for public key authentication'],
|
||||
['escape', 'e', '~', "Set escape character; ``none'' = disable"],
|
||||
['cipher', 'c', None, 'Select encryption algorithm.'],
|
||||
['macs', 'm', None, 'Specify MAC algorithms for protocol version 2.'],
|
||||
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
|
||||
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
|
||||
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
|
||||
]
|
||||
|
||||
optFlags = [['tty', 't', 'Tty; allocate a tty even if command is given.'],
|
||||
['notty', 'T', 'Do not allocate a tty.'],
|
||||
['version', 'V', 'Display version number only.'],
|
||||
['compress', 'C', 'Enable compression.'],
|
||||
['noshell', 'N', 'Do not execute a shell or command.'],
|
||||
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
|
||||
['log', 'v', 'Log to stderr'],
|
||||
['ansilog', 'a', 'Print the received data to stdout']]
|
||||
|
||||
_ciphers = transport.SSHClientTransport.supportedCiphers
|
||||
_macs = transport.SSHClientTransport.supportedMACs
|
||||
|
||||
compData = usage.Completions(
|
||||
mutuallyExclusive=[("tty", "notty")],
|
||||
optActions={
|
||||
"cipher": usage.CompleteList(_ciphers),
|
||||
"macs": usage.CompleteList(_macs),
|
||||
"localforward": usage.Completer(descr="listen-port:host:port"),
|
||||
"remoteforward": usage.Completer(descr="listen-port:host:port")},
|
||||
extraActions=[usage.CompleteUserAtHost(),
|
||||
usage.Completer(descr="command"),
|
||||
usage.Completer(descr="argument", repeat=True)]
|
||||
)
|
||||
|
||||
identitys = []
|
||||
localForwards = []
|
||||
remoteForwards = []
|
||||
|
||||
def opt_identity(self, i):
|
||||
self.identitys.append(i)
|
||||
|
||||
def opt_localforward(self, f):
|
||||
localPort, remoteHost, remotePort = f.split(':') # doesn't do v6 yet
|
||||
localPort = int(localPort)
|
||||
remotePort = int(remotePort)
|
||||
self.localForwards.append((localPort, (remoteHost, remotePort)))
|
||||
|
||||
def opt_remoteforward(self, f):
|
||||
remotePort, connHost, connPort = f.split(':') # doesn't do v6 yet
|
||||
remotePort = int(remotePort)
|
||||
connPort = int(connPort)
|
||||
self.remoteForwards.append((remotePort, (connHost, connPort)))
|
||||
|
||||
def opt_compress(self):
|
||||
SSHClientTransport.supportedCompressions[0:1] = ['zlib']
|
||||
|
||||
def parseArgs(self, *args):
|
||||
if args:
|
||||
self['host'] = args[0]
|
||||
self['command'] = ' '.join(args[1:])
|
||||
else:
|
||||
self['host'] = ''
|
||||
self['command'] = ''
|
||||
|
||||
# Rest of code in "run"
|
||||
options = None
|
||||
menu = None
|
||||
exitStatus = 0
|
||||
frame = None
|
||||
|
||||
def deferredAskFrame(question, echo):
|
||||
if frame.callback:
|
||||
raise ValueError("can't ask 2 questions at once!")
|
||||
d = defer.Deferred()
|
||||
resp = []
|
||||
def gotChar(ch, resp=resp):
|
||||
if not ch: return
|
||||
if ch=='\x03': # C-c
|
||||
reactor.stop()
|
||||
if ch=='\r':
|
||||
frame.write('\r\n')
|
||||
stresp = ''.join(resp)
|
||||
del resp
|
||||
frame.callback = None
|
||||
d.callback(stresp)
|
||||
return
|
||||
elif 32 <= ord(ch) < 127:
|
||||
resp.append(ch)
|
||||
if echo:
|
||||
frame.write(ch)
|
||||
elif ord(ch) == 8 and resp: # BS
|
||||
if echo: frame.write('\x08 \x08')
|
||||
resp.pop()
|
||||
frame.callback = gotChar
|
||||
frame.write(question)
|
||||
frame.canvas.focus_force()
|
||||
return d
|
||||
|
||||
def run():
|
||||
global menu, options, frame
|
||||
args = sys.argv[1:]
|
||||
if '-l' in args: # cvs is an idiot
|
||||
i = args.index('-l')
|
||||
args = args[i:i+2]+args
|
||||
del args[i+2:i+4]
|
||||
for arg in args[:]:
|
||||
try:
|
||||
i = args.index(arg)
|
||||
if arg[:2] == '-o' and args[i+1][0]!='-':
|
||||
args[i:i+2] = [] # suck on it scp
|
||||
except ValueError:
|
||||
pass
|
||||
root = Tkinter.Tk()
|
||||
root.withdraw()
|
||||
top = Tkinter.Toplevel()
|
||||
menu = TkConchMenu(top)
|
||||
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
|
||||
options = GeneralOptions()
|
||||
try:
|
||||
options.parseOptions(args)
|
||||
except usage.UsageError, u:
|
||||
print 'ERROR: %s' % u
|
||||
options.opt_help()
|
||||
sys.exit(1)
|
||||
for k,v in options.items():
|
||||
if v and hasattr(menu, k):
|
||||
getattr(menu,k).insert(Tkinter.END, v)
|
||||
for (p, (rh, rp)) in options.localForwards:
|
||||
menu.forwards.insert(Tkinter.END, 'L:%s:%s:%s' % (p, rh, rp))
|
||||
options.localForwards = []
|
||||
for (p, (rh, rp)) in options.remoteForwards:
|
||||
menu.forwards.insert(Tkinter.END, 'R:%s:%s:%s' % (p, rh, rp))
|
||||
options.remoteForwards = []
|
||||
frame = tkvt100.VT100Frame(root, callback=None)
|
||||
root.geometry('%dx%d'%(tkvt100.fontWidth*frame.width+3, tkvt100.fontHeight*frame.height+3))
|
||||
frame.pack(side = Tkinter.TOP)
|
||||
tksupport.install(root)
|
||||
root.withdraw()
|
||||
if (options['host'] and options['user']) or '@' in options['host']:
|
||||
menu.doConnect()
|
||||
else:
|
||||
top.mainloop()
|
||||
reactor.run()
|
||||
sys.exit(exitStatus)
|
||||
|
||||
def handleError():
|
||||
from twisted.python import failure
|
||||
global exitStatus
|
||||
exitStatus = 2
|
||||
log.err(failure.Failure())
|
||||
reactor.stop()
|
||||
raise
|
||||
|
||||
class SSHClientFactory(protocol.ClientFactory):
|
||||
noisy = 1
|
||||
|
||||
def stopFactory(self):
|
||||
reactor.stop()
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return SSHClientTransport()
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
tkMessageBox.showwarning('TkConch','Connection Failed, Reason:\n %s: %s' % (reason.type, reason.value))
|
||||
|
||||
class SSHClientTransport(transport.SSHClientTransport):
|
||||
|
||||
def receiveError(self, code, desc):
|
||||
global exitStatus
|
||||
exitStatus = 'conch:\tRemote side disconnected with error code %i\nconch:\treason: %s' % (code, desc)
|
||||
|
||||
def sendDisconnect(self, code, reason):
|
||||
global exitStatus
|
||||
exitStatus = 'conch:\tSending disconnect with error code %i\nconch:\treason: %s' % (code, reason)
|
||||
transport.SSHClientTransport.sendDisconnect(self, code, reason)
|
||||
|
||||
def receiveDebug(self, alwaysDisplay, message, lang):
|
||||
global options
|
||||
if alwaysDisplay or options['log']:
|
||||
log.msg('Received Debug Message: %s' % message)
|
||||
|
||||
def verifyHostKey(self, pubKey, fingerprint):
|
||||
#d = defer.Deferred()
|
||||
#d.addCallback(lambda x:defer.succeed(1))
|
||||
#d.callback(2)
|
||||
#return d
|
||||
goodKey = isInKnownHosts(options['host'], pubKey, {'known-hosts': None})
|
||||
if goodKey == 1: # good key
|
||||
return defer.succeed(1)
|
||||
elif goodKey == 2: # AAHHHHH changed
|
||||
return defer.fail(error.ConchError('bad host key'))
|
||||
else:
|
||||
if options['host'] == self.transport.getPeer()[1]:
|
||||
host = options['host']
|
||||
khHost = options['host']
|
||||
else:
|
||||
host = '%s (%s)' % (options['host'],
|
||||
self.transport.getPeer()[1])
|
||||
khHost = '%s,%s' % (options['host'],
|
||||
self.transport.getPeer()[1])
|
||||
keyType = common.getNS(pubKey)[0]
|
||||
ques = """The authenticity of host '%s' can't be established.\r
|
||||
%s key fingerprint is %s.""" % (host,
|
||||
{'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType],
|
||||
fingerprint)
|
||||
ques+='\r\nAre you sure you want to continue connecting (yes/no)? '
|
||||
return deferredAskFrame(ques, 1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
|
||||
|
||||
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
|
||||
if ans.lower() not in ('yes', 'no'):
|
||||
return deferredAskFrame("Please type 'yes' or 'no': ",1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
|
||||
if ans.lower() == 'no':
|
||||
frame.write('Host key verification failed.\r\n')
|
||||
raise error.ConchError('bad host key')
|
||||
try:
|
||||
frame.write("Warning: Permanently added '%s' (%s) to the list of known hosts.\r\n" % (khHost, {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType]))
|
||||
known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'), 'a')
|
||||
encodedKey = base64.encodestring(pubKey).replace('\n', '')
|
||||
known_hosts.write('\n%s %s %s' % (khHost, keyType, encodedKey))
|
||||
known_hosts.close()
|
||||
except:
|
||||
log.deferr()
|
||||
raise error.ConchError
|
||||
|
||||
def connectionSecure(self):
|
||||
if options['user']:
|
||||
user = options['user']
|
||||
else:
|
||||
user = getpass.getuser()
|
||||
self.requestService(SSHUserAuthClient(user, SSHConnection()))
|
||||
|
||||
class SSHUserAuthClient(userauth.SSHUserAuthClient):
|
||||
usedFiles = []
|
||||
|
||||
def getPassword(self, prompt = None):
|
||||
if not prompt:
|
||||
prompt = "%s@%s's password: " % (self.user, options['host'])
|
||||
return deferredAskFrame(prompt,0)
|
||||
|
||||
def getPublicKey(self):
|
||||
files = [x for x in options.identitys if x not in self.usedFiles]
|
||||
if not files:
|
||||
return None
|
||||
file = files[0]
|
||||
log.msg(file)
|
||||
self.usedFiles.append(file)
|
||||
file = os.path.expanduser(file)
|
||||
file += '.pub'
|
||||
if not os.path.exists(file):
|
||||
return
|
||||
try:
|
||||
return keys.Key.fromFile(file).blob()
|
||||
except:
|
||||
return self.getPublicKey() # try again
|
||||
|
||||
def getPrivateKey(self):
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
if not os.path.exists(file):
|
||||
return None
|
||||
try:
|
||||
return defer.succeed(keys.Key.fromFile(file).keyObject)
|
||||
except keys.BadKeyError, e:
|
||||
if e.args[0] == 'encrypted key with no password':
|
||||
prompt = "Enter passphrase for key '%s': " % \
|
||||
self.usedFiles[-1]
|
||||
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
|
||||
def _cbGetPrivateKey(self, ans, count):
|
||||
file = os.path.expanduser(self.usedFiles[-1])
|
||||
try:
|
||||
return keys.Key.fromFile(file, password = ans).keyObject
|
||||
except keys.BadKeyError:
|
||||
if count == 2:
|
||||
raise
|
||||
prompt = "Enter passphrase for key '%s': " % \
|
||||
self.usedFiles[-1]
|
||||
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, count+1)
|
||||
|
||||
class SSHConnection(connection.SSHConnection):
|
||||
def serviceStarted(self):
|
||||
if not options['noshell']:
|
||||
self.openChannel(SSHSession())
|
||||
if options.localForwards:
|
||||
for localPort, hostport in options.localForwards:
|
||||
reactor.listenTCP(localPort,
|
||||
forwarding.SSHListenForwardingFactory(self,
|
||||
hostport,
|
||||
forwarding.SSHListenClientForwardingChannel))
|
||||
if options.remoteForwards:
|
||||
for remotePort, hostport in options.remoteForwards:
|
||||
log.msg('asking for remote forwarding for %s:%s' %
|
||||
(remotePort, hostport))
|
||||
data = forwarding.packGlobal_tcpip_forward(
|
||||
('0.0.0.0', remotePort))
|
||||
self.sendGlobalRequest('tcpip-forward', data)
|
||||
self.remoteForwards[remotePort] = hostport
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
|
||||
name = 'session'
|
||||
|
||||
def channelOpen(self, foo):
|
||||
#global globalSession
|
||||
#globalSession = self
|
||||
# turn off local echo
|
||||
self.escapeMode = 1
|
||||
c = session.SSHSessionClient()
|
||||
if options['escape']:
|
||||
c.dataReceived = self.handleInput
|
||||
else:
|
||||
c.dataReceived = self.write
|
||||
c.connectionLost = self.sendEOF
|
||||
frame.callback = c.dataReceived
|
||||
frame.canvas.focus_force()
|
||||
if options['subsystem']:
|
||||
self.conn.sendRequest(self, 'subsystem', \
|
||||
common.NS(options['command']))
|
||||
elif options['command']:
|
||||
if options['tty']:
|
||||
term = os.environ.get('TERM', 'xterm')
|
||||
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, '')
|
||||
self.conn.sendRequest(self, 'pty-req', ptyReqData)
|
||||
self.conn.sendRequest(self, 'exec', \
|
||||
common.NS(options['command']))
|
||||
else:
|
||||
if not options['notty']:
|
||||
term = os.environ.get('TERM', 'xterm')
|
||||
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
|
||||
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
|
||||
ptyReqData = session.packRequest_pty_req(term, winSize, '')
|
||||
self.conn.sendRequest(self, 'pty-req', ptyReqData)
|
||||
self.conn.sendRequest(self, 'shell', '')
|
||||
self.conn.transport.transport.setTcpNoDelay(1)
|
||||
|
||||
def handleInput(self, char):
|
||||
#log.msg('handling %s' % repr(char))
|
||||
if char in ('\n', '\r'):
|
||||
self.escapeMode = 1
|
||||
self.write(char)
|
||||
elif self.escapeMode == 1 and char == options['escape']:
|
||||
self.escapeMode = 2
|
||||
elif self.escapeMode == 2:
|
||||
self.escapeMode = 1 # so we can chain escapes together
|
||||
if char == '.': # disconnect
|
||||
log.msg('disconnecting from escape')
|
||||
reactor.stop()
|
||||
return
|
||||
elif char == '\x1a': # ^Z, suspend
|
||||
# following line courtesy of Erwin@freenode
|
||||
os.kill(os.getpid(), signal.SIGSTOP)
|
||||
return
|
||||
elif char == 'R': # rekey connection
|
||||
log.msg('rekeying connection')
|
||||
self.conn.transport.sendKexInit()
|
||||
return
|
||||
self.write('~' + char)
|
||||
else:
|
||||
self.escapeMode = 0
|
||||
self.write(char)
|
||||
|
||||
def dataReceived(self, data):
|
||||
if options['ansilog']:
|
||||
print repr(data)
|
||||
frame.write(data)
|
||||
|
||||
def extReceived(self, t, data):
|
||||
if t==connection.EXTENDED_DATA_STDERR:
|
||||
log.msg('got %s stderr data' % len(data))
|
||||
sys.stderr.write(data)
|
||||
sys.stderr.flush()
|
||||
|
||||
def eofReceived(self):
|
||||
log.msg('got eof')
|
||||
sys.stdin.close()
|
||||
|
||||
def closed(self):
|
||||
log.msg('closed %s' % self)
|
||||
if len(self.conn.channels) == 1: # just us left
|
||||
reactor.stop()
|
||||
|
||||
def request_exit_status(self, data):
|
||||
global exitStatus
|
||||
exitStatus = int(struct.unpack('>L', data)[0])
|
||||
log.msg('exit status: %s' % exitStatus)
|
||||
|
||||
def sendEOF(self):
|
||||
self.conn.sendEOF(self)
|
||||
|
||||
if __name__=="__main__":
|
||||
run()
|
@ -0,0 +1,10 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
"""
|
||||
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
@ -0,0 +1,41 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_address -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Address object for SSH network connections.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
|
||||
@since: 12.1
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.interfaces import IAddress
|
||||
from twisted.python import util
|
||||
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class SSHTransportAddress(object, util.FancyEqMixin):
|
||||
"""
|
||||
Object representing an SSH Transport endpoint.
|
||||
|
||||
@ivar address: A instance of an object which implements I{IAddress} to
|
||||
which this transport address is connected.
|
||||
"""
|
||||
|
||||
compareAttributes = ('address',)
|
||||
|
||||
def __init__(self, address):
|
||||
self.address = address
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
return 'SSHTransportAddress(%r)' % (self.address,)
|
||||
|
||||
|
||||
def __hash__(self):
|
||||
return hash(('SSH', self.address))
|
||||
|
@ -0,0 +1,294 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implements the SSH v2 key agent protocol. This protocol is documented in the
|
||||
SSH source code, in the file
|
||||
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch.ssh.common import NS, getNS, getMP
|
||||
from twisted.conch.error import ConchError, MissingKeyStoreError
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.internet import defer, protocol
|
||||
|
||||
|
||||
|
||||
class SSHAgentClient(protocol.Protocol):
|
||||
"""
|
||||
The client side of the SSH agent protocol. This is equivalent to
|
||||
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
|
||||
protocol, also in this package.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buf = ''
|
||||
self.deferreds = []
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
while 1:
|
||||
if len(self.buf) <= 4:
|
||||
return
|
||||
packLen = struct.unpack('!L', self.buf[:4])[0]
|
||||
if len(self.buf) < 4 + packLen:
|
||||
return
|
||||
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
|
||||
reqType = ord(packet[0])
|
||||
d = self.deferreds.pop(0)
|
||||
if reqType == AGENT_FAILURE:
|
||||
d.errback(ConchError('agent failure'))
|
||||
elif reqType == AGENT_SUCCESS:
|
||||
d.callback('')
|
||||
else:
|
||||
d.callback(packet)
|
||||
|
||||
|
||||
def sendRequest(self, reqType, data):
|
||||
pack = struct.pack('!LB',len(data) + 1, reqType) + data
|
||||
self.transport.write(pack)
|
||||
d = defer.Deferred()
|
||||
self.deferreds.append(d)
|
||||
return d
|
||||
|
||||
|
||||
def requestIdentities(self):
|
||||
"""
|
||||
@return: A L{Deferred} which will fire with a list of all keys found in
|
||||
the SSH agent. The list of keys is comprised of (public key blob,
|
||||
comment) tuples.
|
||||
"""
|
||||
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, '')
|
||||
d.addCallback(self._cbRequestIdentities)
|
||||
return d
|
||||
|
||||
|
||||
def _cbRequestIdentities(self, data):
|
||||
"""
|
||||
Unpack a collection of identities into a list of tuples comprised of
|
||||
public key blobs and comments.
|
||||
"""
|
||||
if ord(data[0]) != AGENT_IDENTITIES_ANSWER:
|
||||
raise ConchError('unexpected response: %i' % ord(data[0]))
|
||||
numKeys = struct.unpack('!L', data[1:5])[0]
|
||||
result = []
|
||||
data = data[5:]
|
||||
for i in range(numKeys):
|
||||
blob, data = getNS(data)
|
||||
comment, data = getNS(data)
|
||||
result.append((blob, comment))
|
||||
return result
|
||||
|
||||
|
||||
def addIdentity(self, blob, comment = ''):
|
||||
"""
|
||||
Add a private key blob to the agent's collection of keys.
|
||||
"""
|
||||
req = blob
|
||||
req += NS(comment)
|
||||
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
|
||||
|
||||
|
||||
def signData(self, blob, data):
|
||||
"""
|
||||
Request that the agent sign the given C{data} with the private key
|
||||
which corresponds to the public key given by C{blob}. The private
|
||||
key should have been added to the agent already.
|
||||
|
||||
@type blob: C{str}
|
||||
@type data: C{str}
|
||||
@return: A L{Deferred} which fires with a signature for given data
|
||||
created with the given key.
|
||||
"""
|
||||
req = NS(blob)
|
||||
req += NS(data)
|
||||
req += '\000\000\000\000' # flags
|
||||
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
|
||||
|
||||
|
||||
def _cbSignData(self, data):
|
||||
if ord(data[0]) != AGENT_SIGN_RESPONSE:
|
||||
raise ConchError('unexpected data: %i' % ord(data[0]))
|
||||
signature = getNS(data[1:])[0]
|
||||
return signature
|
||||
|
||||
|
||||
def removeIdentity(self, blob):
|
||||
"""
|
||||
Remove the private key corresponding to the public key in blob from the
|
||||
running agent.
|
||||
"""
|
||||
req = NS(blob)
|
||||
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
|
||||
|
||||
|
||||
def removeAllIdentities(self):
|
||||
"""
|
||||
Remove all keys from the running agent.
|
||||
"""
|
||||
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, '')
|
||||
|
||||
|
||||
|
||||
class SSHAgentServer(protocol.Protocol):
|
||||
"""
|
||||
The server side of the SSH agent protocol. This is equivalent to
|
||||
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
|
||||
protocol, also in this package.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buf = ''
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
while 1:
|
||||
if len(self.buf) <= 4:
|
||||
return
|
||||
packLen = struct.unpack('!L', self.buf[:4])[0]
|
||||
if len(self.buf) < 4 + packLen:
|
||||
return
|
||||
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
|
||||
reqType = ord(packet[0])
|
||||
reqName = messages.get(reqType, None)
|
||||
if not reqName:
|
||||
self.sendResponse(AGENT_FAILURE, '')
|
||||
else:
|
||||
f = getattr(self, 'agentc_%s' % reqName)
|
||||
if getattr(self.factory, 'keys', None) is None:
|
||||
self.sendResponse(AGENT_FAILURE, '')
|
||||
raise MissingKeyStoreError()
|
||||
f(packet[1:])
|
||||
|
||||
|
||||
def sendResponse(self, reqType, data):
|
||||
pack = struct.pack('!LB', len(data) + 1, reqType) + data
|
||||
self.transport.write(pack)
|
||||
|
||||
|
||||
def agentc_REQUEST_IDENTITIES(self, data):
|
||||
"""
|
||||
Return all of the identities that have been added to the server
|
||||
"""
|
||||
assert data == ''
|
||||
numKeys = len(self.factory.keys)
|
||||
resp = []
|
||||
|
||||
resp.append(struct.pack('!L', numKeys))
|
||||
for key, comment in self.factory.keys.itervalues():
|
||||
resp.append(NS(key.blob())) # yes, wrapped in an NS
|
||||
resp.append(NS(comment))
|
||||
self.sendResponse(AGENT_IDENTITIES_ANSWER, ''.join(resp))
|
||||
|
||||
|
||||
def agentc_SIGN_REQUEST(self, data):
|
||||
"""
|
||||
Data is a structure with a reference to an already added key object and
|
||||
some data that the clients wants signed with that key. If the key
|
||||
object wasn't loaded, return AGENT_FAILURE, else return the signature.
|
||||
"""
|
||||
blob, data = getNS(data)
|
||||
if blob not in self.factory.keys:
|
||||
return self.sendResponse(AGENT_FAILURE, '')
|
||||
signData, data = getNS(data)
|
||||
assert data == '\000\000\000\000'
|
||||
self.sendResponse(AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData)))
|
||||
|
||||
|
||||
def agentc_ADD_IDENTITY(self, data):
|
||||
"""
|
||||
Adds a private key to the agent's collection of identities. On
|
||||
subsequent interactions, the private key can be accessed using only the
|
||||
corresponding public key.
|
||||
"""
|
||||
|
||||
# need to pre-read the key data so we can get past it to the comment string
|
||||
keyType, rest = getNS(data)
|
||||
if keyType == 'ssh-rsa':
|
||||
nmp = 6
|
||||
elif keyType == 'ssh-dss':
|
||||
nmp = 5
|
||||
else:
|
||||
raise keys.BadKeyError('unknown blob type: %s' % keyType)
|
||||
|
||||
rest = getMP(rest, nmp)[-1] # ignore the key data for now, we just want the comment
|
||||
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
|
||||
|
||||
k = keys.Key.fromString(data, type='private_blob') # not wrapped in NS here
|
||||
self.factory.keys[k.blob()] = (k, comment)
|
||||
self.sendResponse(AGENT_SUCCESS, '')
|
||||
|
||||
|
||||
def agentc_REMOVE_IDENTITY(self, data):
|
||||
"""
|
||||
Remove a specific key from the agent's collection of identities.
|
||||
"""
|
||||
blob, _ = getNS(data)
|
||||
k = keys.Key.fromString(blob, type='blob')
|
||||
del self.factory.keys[k.blob()]
|
||||
self.sendResponse(AGENT_SUCCESS, '')
|
||||
|
||||
|
||||
def agentc_REMOVE_ALL_IDENTITIES(self, data):
|
||||
"""
|
||||
Remove all keys from the agent's collection of identities.
|
||||
"""
|
||||
assert data == ''
|
||||
self.factory.keys = {}
|
||||
self.sendResponse(AGENT_SUCCESS, '')
|
||||
|
||||
# v1 messages that we ignore because we don't keep v1 keys
|
||||
# open-ssh sends both v1 and v2 commands, so we have to
|
||||
# do no-ops for v1 commands or we'll get "bad request" errors
|
||||
|
||||
def agentc_REQUEST_RSA_IDENTITIES(self, data):
|
||||
"""
|
||||
v1 message for listing RSA1 keys; superseded by
|
||||
agentc_REQUEST_IDENTITIES, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack('!L', 0))
|
||||
|
||||
|
||||
def agentc_REMOVE_RSA_IDENTITY(self, data):
|
||||
"""
|
||||
v1 message for removing RSA1 keys; superseded by
|
||||
agentc_REMOVE_IDENTITY, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_SUCCESS, '')
|
||||
|
||||
|
||||
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
|
||||
"""
|
||||
v1 message for removing all RSA1 keys; superseded by
|
||||
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
|
||||
"""
|
||||
self.sendResponse(AGENT_SUCCESS, '')
|
||||
|
||||
|
||||
AGENTC_REQUEST_RSA_IDENTITIES = 1
|
||||
AGENT_RSA_IDENTITIES_ANSWER = 2
|
||||
AGENT_FAILURE = 5
|
||||
AGENT_SUCCESS = 6
|
||||
|
||||
AGENTC_REMOVE_RSA_IDENTITY = 8
|
||||
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
|
||||
|
||||
AGENTC_REQUEST_IDENTITIES = 11
|
||||
AGENT_IDENTITIES_ANSWER = 12
|
||||
AGENTC_SIGN_REQUEST = 13
|
||||
AGENT_SIGN_RESPONSE = 14
|
||||
AGENTC_ADD_IDENTITY = 17
|
||||
AGENTC_REMOVE_IDENTITY = 18
|
||||
AGENTC_REMOVE_ALL_IDENTITIES = 19
|
||||
|
||||
messages = {}
|
||||
for name, value in locals().copy().items():
|
||||
if name[:7] == 'AGENTC_':
|
||||
messages[value] = name[7:] # doesn't handle doubles
|
||||
|
@ -0,0 +1,281 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_channel -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
"""
|
||||
The parent class for all the SSH Channels. Currently implemented channels
|
||||
are session. direct-tcp, and forwarded-tcp.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.python import log
|
||||
from twisted.internet import interfaces
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport)
|
||||
class SSHChannel(log.Logger):
|
||||
"""
|
||||
A class that represents a multiplexed channel over an SSH connection.
|
||||
The channel has a local window which is the maximum amount of data it will
|
||||
receive, and a remote which is the maximum amount of data the remote side
|
||||
will accept. There is also a maximum packet size for any individual data
|
||||
packet going each way.
|
||||
|
||||
@ivar name: the name of the channel.
|
||||
@type name: C{str}
|
||||
@ivar localWindowSize: the maximum size of the local window in bytes.
|
||||
@type localWindowSize: C{int}
|
||||
@ivar localWindowLeft: how many bytes are left in the local window.
|
||||
@type localWindowLeft: C{int}
|
||||
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
|
||||
@type localMaxPacket: C{int}
|
||||
@ivar remoteWindowLeft: how many bytes are left in the remote window.
|
||||
@type remoteWindowLeft: C{int}
|
||||
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
|
||||
accept in bytes.
|
||||
@type remoteMaxPacket: C{int}
|
||||
@ivar conn: the connection this channel is multiplexed through.
|
||||
@type conn: L{SSHConnection}
|
||||
@ivar data: any data to send to the other size when the channel is
|
||||
requested.
|
||||
@type data: C{str}
|
||||
@ivar avatar: an avatar for the logged-in user (if a server channel)
|
||||
@ivar localClosed: True if we aren't accepting more data.
|
||||
@type localClosed: C{bool}
|
||||
@ivar remoteClosed: True if the other size isn't accepting more data.
|
||||
@type remoteClosed: C{bool}
|
||||
"""
|
||||
|
||||
name = None # only needed for client channels
|
||||
|
||||
def __init__(self, localWindow = 0, localMaxPacket = 0,
|
||||
remoteWindow = 0, remoteMaxPacket = 0,
|
||||
conn = None, data=None, avatar = None):
|
||||
self.localWindowSize = localWindow or 131072
|
||||
self.localWindowLeft = self.localWindowSize
|
||||
self.localMaxPacket = localMaxPacket or 32768
|
||||
self.remoteWindowLeft = remoteWindow
|
||||
self.remoteMaxPacket = remoteMaxPacket
|
||||
self.areWriting = 1
|
||||
self.conn = conn
|
||||
self.data = data
|
||||
self.avatar = avatar
|
||||
self.specificData = ''
|
||||
self.buf = ''
|
||||
self.extBuf = []
|
||||
self.closing = 0
|
||||
self.localClosed = 0
|
||||
self.remoteClosed = 0
|
||||
self.id = None # gets set later by SSHConnection
|
||||
|
||||
def __str__(self):
|
||||
return '<SSHChannel %s (lw %i rw %i)>' % (self.name,
|
||||
self.localWindowLeft, self.remoteWindowLeft)
|
||||
|
||||
def logPrefix(self):
|
||||
id = (self.id is not None and str(self.id)) or "unknown"
|
||||
return "SSHChannel %s (%s) on %s" % (self.name, id,
|
||||
self.conn.logPrefix())
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
Called when the channel is opened. specificData is any data that the
|
||||
other side sent us when opening the channel.
|
||||
|
||||
@type specificData: C{str}
|
||||
"""
|
||||
log.msg('channel open')
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
Called when the open failed for some reason.
|
||||
reason.desc is a string descrption, reason.code the SSH error code.
|
||||
|
||||
@type reason: L{error.ConchError}
|
||||
"""
|
||||
log.msg('other side refused open\nreason: %s'% reason)
|
||||
|
||||
def addWindowBytes(self, bytes):
|
||||
"""
|
||||
Called when bytes are added to the remote window. By default it clears
|
||||
the data buffers.
|
||||
|
||||
@type bytes: C{int}
|
||||
"""
|
||||
self.remoteWindowLeft = self.remoteWindowLeft+bytes
|
||||
if not self.areWriting and not self.closing:
|
||||
self.areWriting = True
|
||||
self.startWriting()
|
||||
if self.buf:
|
||||
b = self.buf
|
||||
self.buf = ''
|
||||
self.write(b)
|
||||
if self.extBuf:
|
||||
b = self.extBuf
|
||||
self.extBuf = []
|
||||
for (type, data) in b:
|
||||
self.writeExtended(type, data)
|
||||
|
||||
def requestReceived(self, requestType, data):
|
||||
"""
|
||||
Called when a request is sent to this channel. By default it delegates
|
||||
to self.request_<requestType>.
|
||||
If this function returns true, the request succeeded, otherwise it
|
||||
failed.
|
||||
|
||||
@type requestType: C{str}
|
||||
@type data: C{str}
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
foo = requestType.replace('-', '_')
|
||||
f = getattr(self, 'request_%s'%foo, None)
|
||||
if f:
|
||||
return f(data)
|
||||
log.msg('unhandled request for %s'%requestType)
|
||||
return 0
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Called when we receive data.
|
||||
|
||||
@type data: C{str}
|
||||
"""
|
||||
log.msg('got data %s'%repr(data))
|
||||
|
||||
def extReceived(self, dataType, data):
|
||||
"""
|
||||
Called when we receive extended data (usually standard error).
|
||||
|
||||
@type dataType: C{int}
|
||||
@type data: C{str}
|
||||
"""
|
||||
log.msg('got extended data %s %s'%(dataType, repr(data)))
|
||||
|
||||
def eofReceived(self):
|
||||
"""
|
||||
Called when the other side will send no more data.
|
||||
"""
|
||||
log.msg('remote eof')
|
||||
|
||||
def closeReceived(self):
|
||||
"""
|
||||
Called when the other side has closed the channel.
|
||||
"""
|
||||
log.msg('remote close')
|
||||
self.loseConnection()
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
Called when the channel is closed. This means that both our side and
|
||||
the remote side have closed the channel.
|
||||
"""
|
||||
log.msg('closed')
|
||||
|
||||
# transport stuff
|
||||
def write(self, data):
|
||||
"""
|
||||
Write some data to the channel. If there is not enough remote window
|
||||
available, buffer until it is. Otherwise, split the data into
|
||||
packets of length remoteMaxPacket and send them.
|
||||
|
||||
@type data: C{str}
|
||||
"""
|
||||
if self.buf:
|
||||
self.buf += data
|
||||
return
|
||||
top = len(data)
|
||||
if top > self.remoteWindowLeft:
|
||||
data, self.buf = (data[:self.remoteWindowLeft],
|
||||
data[self.remoteWindowLeft:])
|
||||
self.areWriting = 0
|
||||
self.stopWriting()
|
||||
top = self.remoteWindowLeft
|
||||
rmp = self.remoteMaxPacket
|
||||
write = self.conn.sendData
|
||||
r = range(0, top, rmp)
|
||||
for offset in r:
|
||||
write(self, data[offset: offset+rmp])
|
||||
self.remoteWindowLeft -= top
|
||||
if self.closing and not self.buf:
|
||||
self.loseConnection() # try again
|
||||
|
||||
def writeExtended(self, dataType, data):
|
||||
"""
|
||||
Send extended data to this channel. If there is not enough remote
|
||||
window available, buffer until there is. Otherwise, split the data
|
||||
into packets of length remoteMaxPacket and send them.
|
||||
|
||||
@type dataType: C{int}
|
||||
@type data: C{str}
|
||||
"""
|
||||
if self.extBuf:
|
||||
if self.extBuf[-1][0] == dataType:
|
||||
self.extBuf[-1][1] += data
|
||||
else:
|
||||
self.extBuf.append([dataType, data])
|
||||
return
|
||||
if len(data) > self.remoteWindowLeft:
|
||||
data, self.extBuf = (data[:self.remoteWindowLeft],
|
||||
[[dataType, data[self.remoteWindowLeft:]]])
|
||||
self.areWriting = 0
|
||||
self.stopWriting()
|
||||
while len(data) > self.remoteMaxPacket:
|
||||
self.conn.sendExtendedData(self, dataType,
|
||||
data[:self.remoteMaxPacket])
|
||||
data = data[self.remoteMaxPacket:]
|
||||
self.remoteWindowLeft -= self.remoteMaxPacket
|
||||
if data:
|
||||
self.conn.sendExtendedData(self, dataType, data)
|
||||
self.remoteWindowLeft -= len(data)
|
||||
if self.closing:
|
||||
self.loseConnection() # try again
|
||||
|
||||
def writeSequence(self, data):
|
||||
"""
|
||||
Part of the Transport interface. Write a list of strings to the
|
||||
channel.
|
||||
|
||||
@type data: C{list} of C{str}
|
||||
"""
|
||||
self.write(''.join(data))
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Close the channel if there is no buferred data. Otherwise, note the
|
||||
request and return.
|
||||
"""
|
||||
self.closing = 1
|
||||
if not self.buf and not self.extBuf:
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def getPeer(self):
|
||||
"""
|
||||
Return a tuple describing the other side of the connection.
|
||||
|
||||
@rtype: C{tuple}
|
||||
"""
|
||||
return('SSH', )+self.conn.transport.getPeer()
|
||||
|
||||
def getHost(self):
|
||||
"""
|
||||
Return a tuple describing our side of the connection.
|
||||
|
||||
@rtype: C{tuple}
|
||||
"""
|
||||
return('SSH', )+self.conn.transport.getHost()
|
||||
|
||||
def stopWriting(self):
|
||||
"""
|
||||
Called when the remote buffer is full, as a hint to stop writing.
|
||||
This can be ignored, but it can be helpful.
|
||||
"""
|
||||
|
||||
def startWriting(self):
|
||||
"""
|
||||
Called when the remote buffer has more room, as a hint to continue
|
||||
writing.
|
||||
"""
|
@ -0,0 +1,116 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_ssh -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Common functions for the SSH classes.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct, warnings, __builtin__
|
||||
|
||||
try:
|
||||
from Crypto import Util
|
||||
except ImportError:
|
||||
warnings.warn("PyCrypto not installed, but continuing anyways!",
|
||||
RuntimeWarning)
|
||||
|
||||
|
||||
|
||||
def NS(t):
|
||||
"""
|
||||
net string
|
||||
"""
|
||||
return struct.pack('!L',len(t)) + t
|
||||
|
||||
def getNS(s, count=1):
|
||||
"""
|
||||
get net string
|
||||
"""
|
||||
ns = []
|
||||
c = 0
|
||||
for i in range(count):
|
||||
l, = struct.unpack('!L',s[c:c+4])
|
||||
ns.append(s[c+4:4+l+c])
|
||||
c += 4 + l
|
||||
return tuple(ns) + (s[c:],)
|
||||
|
||||
def MP(number):
|
||||
if number==0: return '\000'*4
|
||||
assert number>0
|
||||
bn = Util.number.long_to_bytes(number)
|
||||
if ord(bn[0])&128:
|
||||
bn = '\000' + bn
|
||||
return struct.pack('>L',len(bn)) + bn
|
||||
|
||||
def getMP(data, count=1):
|
||||
"""
|
||||
Get multiple precision integer out of the string. A multiple precision
|
||||
integer is stored as a 4-byte length followed by length bytes of the
|
||||
integer. If count is specified, get count integers out of the string.
|
||||
The return value is a tuple of count integers followed by the rest of
|
||||
the data.
|
||||
"""
|
||||
mp = []
|
||||
c = 0
|
||||
for i in range(count):
|
||||
length, = struct.unpack('>L',data[c:c+4])
|
||||
mp.append(Util.number.bytes_to_long(data[c+4:c+4+length]))
|
||||
c += 4 + length
|
||||
return tuple(mp) + (data[c:],)
|
||||
|
||||
def _MPpow(x, y, z):
|
||||
"""return the MP version of (x**y)%z
|
||||
"""
|
||||
return MP(pow(x,y,z))
|
||||
|
||||
def ffs(c, s):
|
||||
"""
|
||||
first from second
|
||||
goes through the first list, looking for items in the second, returns the first one
|
||||
"""
|
||||
for i in c:
|
||||
if i in s: return i
|
||||
|
||||
getMP_py = getMP
|
||||
MP_py = MP
|
||||
_MPpow_py = _MPpow
|
||||
pyPow = pow
|
||||
|
||||
def _fastgetMP(data, count=1):
|
||||
mp = []
|
||||
c = 0
|
||||
for i in range(count):
|
||||
length = struct.unpack('!L', data[c:c+4])[0]
|
||||
mp.append(long(gmpy.mpz(data[c + 4:c + 4 + length][::-1] + '\x00', 256)))
|
||||
c += length + 4
|
||||
return tuple(mp) + (data[c:],)
|
||||
|
||||
def _fastMP(i):
|
||||
i2 = gmpy.mpz(i).binary()[::-1]
|
||||
return struct.pack('!L', len(i2)) + i2
|
||||
|
||||
def _fastMPpow(x, y, z=None):
|
||||
r = pyPow(gmpy.mpz(x),y,z).binary()[::-1]
|
||||
return struct.pack('!L', len(r)) + r
|
||||
|
||||
def install():
|
||||
global getMP, MP, _MPpow
|
||||
getMP = _fastgetMP
|
||||
MP = _fastMP
|
||||
_MPpow = _fastMPpow
|
||||
# XXX: We override builtin pow so that PyCrypto can benefit from gmpy too.
|
||||
def _fastpow(x, y, z=None, mpz=gmpy.mpz):
|
||||
if type(x) in (long, int):
|
||||
x = mpz(x)
|
||||
return pyPow(x, y, z)
|
||||
__builtin__.pow = _fastpow # evil evil
|
||||
|
||||
try:
|
||||
import gmpy
|
||||
install()
|
||||
except ImportError:
|
||||
pass
|
||||
|
@ -0,0 +1,636 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_connection -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of the ssh-connection service, which
|
||||
allows access to the shell and port-forwarding.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch.ssh import service, common
|
||||
from twisted.conch import error
|
||||
from twisted.internet import defer
|
||||
from twisted.python import log
|
||||
|
||||
class SSHConnection(service.SSHService):
|
||||
"""
|
||||
An implementation of the 'ssh-connection' service. It is used to
|
||||
multiplex multiple channels over the single SSH connection.
|
||||
|
||||
@ivar localChannelID: the next number to use as a local channel ID.
|
||||
@type localChannelID: C{int}
|
||||
@ivar channels: a C{dict} mapping a local channel ID to C{SSHChannel}
|
||||
subclasses.
|
||||
@type channels: C{dict}
|
||||
@ivar localToRemoteChannel: a C{dict} mapping a local channel ID to a
|
||||
remote channel ID.
|
||||
@type localToRemoteChannel: C{dict}
|
||||
@ivar channelsToRemoteChannel: a C{dict} mapping a C{SSHChannel} subclass
|
||||
to remote channel ID.
|
||||
@type channelsToRemoteChannel: C{dict}
|
||||
@ivar deferreds: a C{dict} mapping a local channel ID to a C{list} of
|
||||
C{Deferreds} for outstanding channel requests. Also, the 'global'
|
||||
key stores the C{list} of pending global request C{Deferred}s.
|
||||
"""
|
||||
name = 'ssh-connection'
|
||||
|
||||
def __init__(self):
|
||||
self.localChannelID = 0 # this is the current # to use for channel ID
|
||||
self.localToRemoteChannel = {} # local channel ID -> remote channel ID
|
||||
self.channels = {} # local channel ID -> subclass of SSHChannel
|
||||
self.channelsToRemoteChannel = {} # subclass of SSHChannel ->
|
||||
# remote channel ID
|
||||
self.deferreds = {"global": []} # local channel -> list of deferreds
|
||||
# for pending requests or 'global' -> list of
|
||||
# deferreds for global requests
|
||||
self.transport = None # gets set later
|
||||
|
||||
|
||||
def serviceStarted(self):
|
||||
if hasattr(self.transport, 'avatar'):
|
||||
self.transport.avatar.conn = self
|
||||
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
Called when the connection is stopped.
|
||||
"""
|
||||
map(self.channelClosed, self.channels.values())
|
||||
self._cleanupGlobalDeferreds()
|
||||
|
||||
|
||||
def _cleanupGlobalDeferreds(self):
|
||||
"""
|
||||
All pending requests that have returned a deferred must be errbacked
|
||||
when this service is stopped, otherwise they might be left uncalled and
|
||||
uncallable.
|
||||
"""
|
||||
for d in self.deferreds["global"]:
|
||||
d.errback(error.ConchError("Connection stopped."))
|
||||
del self.deferreds["global"][:]
|
||||
|
||||
|
||||
# packet methods
|
||||
def ssh_GLOBAL_REQUEST(self, packet):
|
||||
"""
|
||||
The other side has made a global request. Payload::
|
||||
string request type
|
||||
bool want reply
|
||||
<request specific data>
|
||||
|
||||
This dispatches to self.gotGlobalRequest.
|
||||
"""
|
||||
requestType, rest = common.getNS(packet)
|
||||
wantReply, rest = ord(rest[0]), rest[1:]
|
||||
ret = self.gotGlobalRequest(requestType, rest)
|
||||
if wantReply:
|
||||
reply = MSG_REQUEST_FAILURE
|
||||
data = ''
|
||||
if ret:
|
||||
reply = MSG_REQUEST_SUCCESS
|
||||
if isinstance(ret, (tuple, list)):
|
||||
data = ret[1]
|
||||
self.transport.sendPacket(reply, data)
|
||||
|
||||
def ssh_REQUEST_SUCCESS(self, packet):
|
||||
"""
|
||||
Our global request succeeded. Get the appropriate Deferred and call
|
||||
it back with the packet we received.
|
||||
"""
|
||||
log.msg('RS')
|
||||
self.deferreds['global'].pop(0).callback(packet)
|
||||
|
||||
def ssh_REQUEST_FAILURE(self, packet):
|
||||
"""
|
||||
Our global request failed. Get the appropriate Deferred and errback
|
||||
it with the packet we received.
|
||||
"""
|
||||
log.msg('RF')
|
||||
self.deferreds['global'].pop(0).errback(
|
||||
error.ConchError('global request failed', packet))
|
||||
|
||||
def ssh_CHANNEL_OPEN(self, packet):
|
||||
"""
|
||||
The other side wants to get a channel. Payload::
|
||||
string channel name
|
||||
uint32 remote channel number
|
||||
uint32 remote window size
|
||||
uint32 remote maximum packet size
|
||||
<channel specific data>
|
||||
|
||||
We get a channel from self.getChannel(), give it a local channel number
|
||||
and notify the other side. Then notify the channel by calling its
|
||||
channelOpen method.
|
||||
"""
|
||||
channelType, rest = common.getNS(packet)
|
||||
senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
|
||||
packet = rest[12:]
|
||||
try:
|
||||
channel = self.getChannel(channelType, windowSize, maxPacket,
|
||||
packet)
|
||||
localChannel = self.localChannelID
|
||||
self.localChannelID += 1
|
||||
channel.id = localChannel
|
||||
self.channels[localChannel] = channel
|
||||
self.channelsToRemoteChannel[channel] = senderChannel
|
||||
self.localToRemoteChannel[localChannel] = senderChannel
|
||||
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
|
||||
struct.pack('>4L', senderChannel, localChannel,
|
||||
channel.localWindowSize,
|
||||
channel.localMaxPacket)+channel.specificData)
|
||||
log.callWithLogger(channel, channel.channelOpen, packet)
|
||||
except Exception, e:
|
||||
log.err(e, 'channel open failed')
|
||||
if isinstance(e, error.ConchError):
|
||||
textualInfo, reason = e.args
|
||||
if isinstance(textualInfo, (int, long)):
|
||||
# See #3657 and #3071
|
||||
textualInfo, reason = reason, textualInfo
|
||||
else:
|
||||
reason = OPEN_CONNECT_FAILED
|
||||
textualInfo = "unknown failure"
|
||||
self.transport.sendPacket(
|
||||
MSG_CHANNEL_OPEN_FAILURE,
|
||||
struct.pack('>2L', senderChannel, reason) +
|
||||
common.NS(textualInfo) + common.NS(''))
|
||||
|
||||
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
|
||||
"""
|
||||
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
|
||||
uint32 local channel number
|
||||
uint32 remote channel number
|
||||
uint32 remote window size
|
||||
uint32 remote maximum packet size
|
||||
<channel specific data>
|
||||
|
||||
Find the channel using the local channel number and notify its
|
||||
channelOpen method.
|
||||
"""
|
||||
(localChannel, remoteChannel, windowSize,
|
||||
maxPacket) = struct.unpack('>4L', packet[: 16])
|
||||
specificData = packet[16:]
|
||||
channel = self.channels[localChannel]
|
||||
channel.conn = self
|
||||
self.localToRemoteChannel[localChannel] = remoteChannel
|
||||
self.channelsToRemoteChannel[channel] = remoteChannel
|
||||
channel.remoteWindowLeft = windowSize
|
||||
channel.remoteMaxPacket = maxPacket
|
||||
log.callWithLogger(channel, channel.channelOpen, specificData)
|
||||
|
||||
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
|
||||
"""
|
||||
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
|
||||
uint32 local channel number
|
||||
uint32 reason code
|
||||
string reason description
|
||||
|
||||
Find the channel using the local channel number and notify it by
|
||||
calling its openFailed() method.
|
||||
"""
|
||||
localChannel, reasonCode = struct.unpack('>2L', packet[:8])
|
||||
reasonDesc = common.getNS(packet[8:])[0]
|
||||
channel = self.channels[localChannel]
|
||||
del self.channels[localChannel]
|
||||
channel.conn = self
|
||||
reason = error.ConchError(reasonDesc, reasonCode)
|
||||
log.callWithLogger(channel, channel.openFailed, reason)
|
||||
|
||||
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
|
||||
"""
|
||||
The other side is adding bytes to its window. Payload::
|
||||
uint32 local channel number
|
||||
uint32 bytes to add
|
||||
|
||||
Call the channel's addWindowBytes() method to add new bytes to the
|
||||
remote window.
|
||||
"""
|
||||
localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
|
||||
channel = self.channels[localChannel]
|
||||
log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)
|
||||
|
||||
def ssh_CHANNEL_DATA(self, packet):
|
||||
"""
|
||||
The other side is sending us data. Payload::
|
||||
uint32 local channel number
|
||||
string data
|
||||
|
||||
Check to make sure the other side hasn't sent too much data (more
|
||||
than what's in the window, or more than the maximum packet size). If
|
||||
they have, close the channel. Otherwise, decrease the available
|
||||
window and pass the data to the channel's dataReceived().
|
||||
"""
|
||||
localChannel, dataLength = struct.unpack('>2L', packet[:8])
|
||||
channel = self.channels[localChannel]
|
||||
# XXX should this move to dataReceived to put client in charge?
|
||||
if (dataLength > channel.localWindowLeft or
|
||||
dataLength > channel.localMaxPacket): # more data than we want
|
||||
log.callWithLogger(channel, log.msg, 'too much data')
|
||||
self.sendClose(channel)
|
||||
return
|
||||
#packet = packet[:channel.localWindowLeft+4]
|
||||
data = common.getNS(packet[4:])[0]
|
||||
channel.localWindowLeft -= dataLength
|
||||
if channel.localWindowLeft < channel.localWindowSize // 2:
|
||||
self.adjustWindow(channel, channel.localWindowSize - \
|
||||
channel.localWindowLeft)
|
||||
#log.msg('local window left: %s/%s' % (channel.localWindowLeft,
|
||||
# channel.localWindowSize))
|
||||
log.callWithLogger(channel, channel.dataReceived, data)
|
||||
|
||||
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
|
||||
"""
|
||||
The other side is sending us exteneded data. Payload::
|
||||
uint32 local channel number
|
||||
uint32 type code
|
||||
string data
|
||||
|
||||
Check to make sure the other side hasn't sent too much data (more
|
||||
than what's in the window, or than the maximum packet size). If
|
||||
they have, close the channel. Otherwise, decrease the available
|
||||
window and pass the data and type code to the channel's
|
||||
extReceived().
|
||||
"""
|
||||
localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
|
||||
channel = self.channels[localChannel]
|
||||
if (dataLength > channel.localWindowLeft or
|
||||
dataLength > channel.localMaxPacket):
|
||||
log.callWithLogger(channel, log.msg, 'too much extdata')
|
||||
self.sendClose(channel)
|
||||
return
|
||||
data = common.getNS(packet[8:])[0]
|
||||
channel.localWindowLeft -= dataLength
|
||||
if channel.localWindowLeft < channel.localWindowSize // 2:
|
||||
self.adjustWindow(channel, channel.localWindowSize -
|
||||
channel.localWindowLeft)
|
||||
log.callWithLogger(channel, channel.extReceived, typeCode, data)
|
||||
|
||||
def ssh_CHANNEL_EOF(self, packet):
|
||||
"""
|
||||
The other side is not sending any more data. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Notify the channel by calling its eofReceived() method.
|
||||
"""
|
||||
localChannel = struct.unpack('>L', packet[:4])[0]
|
||||
channel = self.channels[localChannel]
|
||||
log.callWithLogger(channel, channel.eofReceived)
|
||||
|
||||
def ssh_CHANNEL_CLOSE(self, packet):
|
||||
"""
|
||||
The other side is closing its end; it does not want to receive any
|
||||
more data. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Notify the channnel by calling its closeReceived() method. If
|
||||
the channel has also sent a close message, call self.channelClosed().
|
||||
"""
|
||||
localChannel = struct.unpack('>L', packet[:4])[0]
|
||||
channel = self.channels[localChannel]
|
||||
log.callWithLogger(channel, channel.closeReceived)
|
||||
channel.remoteClosed = True
|
||||
if channel.localClosed and channel.remoteClosed:
|
||||
self.channelClosed(channel)
|
||||
|
||||
def ssh_CHANNEL_REQUEST(self, packet):
|
||||
"""
|
||||
The other side is sending a request to a channel. Payload::
|
||||
uint32 local channel number
|
||||
string request name
|
||||
bool want reply
|
||||
<request specific data>
|
||||
|
||||
Pass the message to the channel's requestReceived method. If the
|
||||
other side wants a reply, add callbacks which will send the
|
||||
reply.
|
||||
"""
|
||||
localChannel = struct.unpack('>L', packet[:4])[0]
|
||||
requestType, rest = common.getNS(packet[4:])
|
||||
wantReply = ord(rest[0])
|
||||
channel = self.channels[localChannel]
|
||||
d = defer.maybeDeferred(log.callWithLogger, channel,
|
||||
channel.requestReceived, requestType, rest[1:])
|
||||
if wantReply:
|
||||
d.addCallback(self._cbChannelRequest, localChannel)
|
||||
d.addErrback(self._ebChannelRequest, localChannel)
|
||||
return d
|
||||
|
||||
def _cbChannelRequest(self, result, localChannel):
|
||||
"""
|
||||
Called back if the other side wanted a reply to a channel request. If
|
||||
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
|
||||
a C{error.ConchError}
|
||||
|
||||
@param result: the value returned from the channel's requestReceived()
|
||||
method. If it's False, the request failed.
|
||||
@type result: C{bool}
|
||||
@param localChannel: the local channel ID of the channel to which the
|
||||
request was made.
|
||||
@type localChannel: C{int}
|
||||
@raises ConchError: if the result is False.
|
||||
"""
|
||||
if not result:
|
||||
raise error.ConchError('failed request')
|
||||
self.transport.sendPacket(MSG_CHANNEL_SUCCESS, struct.pack('>L',
|
||||
self.localToRemoteChannel[localChannel]))
|
||||
|
||||
def _ebChannelRequest(self, result, localChannel):
|
||||
"""
|
||||
Called if the other wisde wanted a reply to the channel requeset and
|
||||
the channel request failed.
|
||||
|
||||
@param result: a Failure, but it's not used.
|
||||
@param localChannel: the local channel ID of the channel to which the
|
||||
request was made.
|
||||
@type localChannel: C{int}
|
||||
"""
|
||||
self.transport.sendPacket(MSG_CHANNEL_FAILURE, struct.pack('>L',
|
||||
self.localToRemoteChannel[localChannel]))
|
||||
|
||||
def ssh_CHANNEL_SUCCESS(self, packet):
|
||||
"""
|
||||
Our channel request to the other side succeeded. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Get the C{Deferred} out of self.deferreds and call it back.
|
||||
"""
|
||||
localChannel = struct.unpack('>L', packet[:4])[0]
|
||||
if self.deferreds.get(localChannel):
|
||||
d = self.deferreds[localChannel].pop(0)
|
||||
log.callWithLogger(self.channels[localChannel],
|
||||
d.callback, '')
|
||||
|
||||
def ssh_CHANNEL_FAILURE(self, packet):
|
||||
"""
|
||||
Our channel request to the other side failed. Payload::
|
||||
uint32 local channel number
|
||||
|
||||
Get the C{Deferred} out of self.deferreds and errback it with a
|
||||
C{error.ConchError}.
|
||||
"""
|
||||
localChannel = struct.unpack('>L', packet[:4])[0]
|
||||
if self.deferreds.get(localChannel):
|
||||
d = self.deferreds[localChannel].pop(0)
|
||||
log.callWithLogger(self.channels[localChannel],
|
||||
d.errback,
|
||||
error.ConchError('channel request failed'))
|
||||
|
||||
# methods for users of the connection to call
|
||||
|
||||
def sendGlobalRequest(self, request, data, wantReply=0):
|
||||
"""
|
||||
Send a global request for this connection. Current this is only used
|
||||
for remote->local TCP forwarding.
|
||||
|
||||
@type request: C{str}
|
||||
@type data: C{str}
|
||||
@type wantReply: C{bool}
|
||||
@rtype C{Deferred}/C{None}
|
||||
"""
|
||||
self.transport.sendPacket(MSG_GLOBAL_REQUEST,
|
||||
common.NS(request)
|
||||
+ (wantReply and '\xff' or '\x00')
|
||||
+ data)
|
||||
if wantReply:
|
||||
d = defer.Deferred()
|
||||
self.deferreds['global'].append(d)
|
||||
return d
|
||||
|
||||
def openChannel(self, channel, extra=''):
|
||||
"""
|
||||
Open a new channel on this connection.
|
||||
|
||||
@type channel: subclass of C{SSHChannel}
|
||||
@type extra: C{str}
|
||||
"""
|
||||
log.msg('opening channel %s with %s %s'%(self.localChannelID,
|
||||
channel.localWindowSize, channel.localMaxPacket))
|
||||
self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
|
||||
+ struct.pack('>3L', self.localChannelID,
|
||||
channel.localWindowSize, channel.localMaxPacket)
|
||||
+ extra)
|
||||
channel.id = self.localChannelID
|
||||
self.channels[self.localChannelID] = channel
|
||||
self.localChannelID += 1
|
||||
|
||||
def sendRequest(self, channel, requestType, data, wantReply=0):
|
||||
"""
|
||||
Send a request to a channel.
|
||||
|
||||
@type channel: subclass of C{SSHChannel}
|
||||
@type requestType: C{str}
|
||||
@type data: C{str}
|
||||
@type wantReply: C{bool}
|
||||
@rtype C{Deferred}/C{None}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return
|
||||
log.msg('sending request %s' % requestType)
|
||||
self.transport.sendPacket(MSG_CHANNEL_REQUEST, struct.pack('>L',
|
||||
self.channelsToRemoteChannel[channel])
|
||||
+ common.NS(requestType)+chr(wantReply)
|
||||
+ data)
|
||||
if wantReply:
|
||||
d = defer.Deferred()
|
||||
self.deferreds.setdefault(channel.id, []).append(d)
|
||||
return d
|
||||
|
||||
def adjustWindow(self, channel, bytesToAdd):
|
||||
"""
|
||||
Tell the other side that we will receive more data. This should not
|
||||
normally need to be called as it is managed automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type bytesToAdd: C{int}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, struct.pack('>2L',
|
||||
self.channelsToRemoteChannel[channel],
|
||||
bytesToAdd))
|
||||
log.msg('adding %i to %i in channel %i' % (bytesToAdd,
|
||||
channel.localWindowLeft, channel.id))
|
||||
channel.localWindowLeft += bytesToAdd
|
||||
|
||||
def sendData(self, channel, data):
|
||||
"""
|
||||
Send data to a channel. This should not normally be used: instead use
|
||||
channel.write(data) as it manages the window automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type data: C{str}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self.transport.sendPacket(MSG_CHANNEL_DATA, struct.pack('>L',
|
||||
self.channelsToRemoteChannel[channel]) +
|
||||
common.NS(data))
|
||||
|
||||
def sendExtendedData(self, channel, dataType, data):
|
||||
"""
|
||||
Send extended data to a channel. This should not normally be used:
|
||||
instead use channel.writeExtendedData(data, dataType) as it manages
|
||||
the window automatically.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
@type dataType: C{int}
|
||||
@type data: C{str}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA, struct.pack('>2L',
|
||||
self.channelsToRemoteChannel[channel],dataType) \
|
||||
+ common.NS(data))
|
||||
|
||||
def sendEOF(self, channel):
|
||||
"""
|
||||
Send an EOF (End of File) for a channel.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
log.msg('sending eof')
|
||||
self.transport.sendPacket(MSG_CHANNEL_EOF, struct.pack('>L',
|
||||
self.channelsToRemoteChannel[channel]))
|
||||
|
||||
def sendClose(self, channel):
|
||||
"""
|
||||
Close a channel.
|
||||
|
||||
@type channel: subclass of L{SSHChannel}
|
||||
"""
|
||||
if channel.localClosed:
|
||||
return # we're already closed
|
||||
log.msg('sending close %i' % channel.id)
|
||||
self.transport.sendPacket(MSG_CHANNEL_CLOSE, struct.pack('>L',
|
||||
self.channelsToRemoteChannel[channel]))
|
||||
channel.localClosed = True
|
||||
if channel.localClosed and channel.remoteClosed:
|
||||
self.channelClosed(channel)
|
||||
|
||||
# methods to override
|
||||
def getChannel(self, channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side requested a channel of some sort.
|
||||
channelType is the type of channel being requested,
|
||||
windowSize is the initial size of the remote window,
|
||||
maxPacket is the largest packet we should send,
|
||||
data is any other packet data (often nothing).
|
||||
|
||||
We return a subclass of L{SSHChannel}.
|
||||
|
||||
By default, this dispatches to a method 'channel_channelType' with any
|
||||
non-alphanumerics in the channelType replace with _'s. If it cannot
|
||||
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
|
||||
The method is called with arguments of windowSize, maxPacket, data.
|
||||
|
||||
@type channelType: C{str}
|
||||
@type windowSize: C{int}
|
||||
@type maxPacket: C{int}
|
||||
@type data: C{str}
|
||||
@rtype: subclass of L{SSHChannel}/C{tuple}
|
||||
"""
|
||||
log.msg('got channel %s request' % channelType)
|
||||
if hasattr(self.transport, "avatar"): # this is a server!
|
||||
chan = self.transport.avatar.lookupChannel(channelType,
|
||||
windowSize,
|
||||
maxPacket,
|
||||
data)
|
||||
else:
|
||||
channelType = channelType.translate(TRANSLATE_TABLE)
|
||||
f = getattr(self, 'channel_%s' % channelType, None)
|
||||
if f is not None:
|
||||
chan = f(windowSize, maxPacket, data)
|
||||
else:
|
||||
chan = None
|
||||
if chan is None:
|
||||
raise error.ConchError('unknown channel',
|
||||
OPEN_UNKNOWN_CHANNEL_TYPE)
|
||||
else:
|
||||
chan.conn = self
|
||||
return chan
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
"""
|
||||
We got a global request. pretty much, this is just used by the client
|
||||
to request that we forward a port from the server to the client.
|
||||
Returns either:
|
||||
- 1: request accepted
|
||||
- 1, <data>: request accepted with request specific data
|
||||
- 0: request denied
|
||||
|
||||
By default, this dispatches to a method 'global_requestType' with
|
||||
-'s in requestType replaced with _'s. The found method is passed data.
|
||||
If this method cannot be found, this method returns 0. Otherwise, it
|
||||
returns the return value of that method.
|
||||
|
||||
@type requestType: C{str}
|
||||
@type data: C{str}
|
||||
@rtype: C{int}/C{tuple}
|
||||
"""
|
||||
log.msg('got global %s request' % requestType)
|
||||
if hasattr(self.transport, 'avatar'): # this is a server!
|
||||
return self.transport.avatar.gotGlobalRequest(requestType, data)
|
||||
|
||||
requestType = requestType.replace('-','_')
|
||||
f = getattr(self, 'global_%s' % requestType, None)
|
||||
if not f:
|
||||
return 0
|
||||
return f(data)
|
||||
|
||||
def channelClosed(self, channel):
|
||||
"""
|
||||
Called when a channel is closed.
|
||||
It clears the local state related to the channel, and calls
|
||||
channel.closed().
|
||||
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
|
||||
If you don't, things will break mysteriously.
|
||||
|
||||
@type channel: L{SSHChannel}
|
||||
"""
|
||||
if channel in self.channelsToRemoteChannel: # actually open
|
||||
channel.localClosed = channel.remoteClosed = True
|
||||
del self.localToRemoteChannel[channel.id]
|
||||
del self.channels[channel.id]
|
||||
del self.channelsToRemoteChannel[channel]
|
||||
for d in self.deferreds.setdefault(channel.id, []):
|
||||
d.errback(error.ConchError("Channel closed."))
|
||||
del self.deferreds[channel.id][:]
|
||||
log.callWithLogger(channel, channel.closed)
|
||||
|
||||
MSG_GLOBAL_REQUEST = 80
|
||||
MSG_REQUEST_SUCCESS = 81
|
||||
MSG_REQUEST_FAILURE = 82
|
||||
MSG_CHANNEL_OPEN = 90
|
||||
MSG_CHANNEL_OPEN_CONFIRMATION = 91
|
||||
MSG_CHANNEL_OPEN_FAILURE = 92
|
||||
MSG_CHANNEL_WINDOW_ADJUST = 93
|
||||
MSG_CHANNEL_DATA = 94
|
||||
MSG_CHANNEL_EXTENDED_DATA = 95
|
||||
MSG_CHANNEL_EOF = 96
|
||||
MSG_CHANNEL_CLOSE = 97
|
||||
MSG_CHANNEL_REQUEST = 98
|
||||
MSG_CHANNEL_SUCCESS = 99
|
||||
MSG_CHANNEL_FAILURE = 100
|
||||
|
||||
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
|
||||
OPEN_CONNECT_FAILED = 2
|
||||
OPEN_UNKNOWN_CHANNEL_TYPE = 3
|
||||
OPEN_RESOURCE_SHORTAGE = 4
|
||||
|
||||
EXTENDED_DATA_STDERR = 1
|
||||
|
||||
messages = {}
|
||||
for name, value in locals().copy().items():
|
||||
if name[:4] == 'MSG_':
|
||||
messages[value] = name # doesn't handle doubles
|
||||
|
||||
import string
|
||||
alphanums = string.letters + string.digits
|
||||
TRANSLATE_TABLE = ''.join([chr(i) in alphanums and chr(i) or '_'
|
||||
for i in range(256)])
|
||||
SSHConnection.protocolMessages = messages
|
@ -0,0 +1,120 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
A Factory for SSH servers, along with an OpenSSHFactory to use the same
|
||||
data sources as OpenSSH.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
from twisted.internet import protocol
|
||||
from twisted.python import log
|
||||
|
||||
from twisted.conch import error
|
||||
import transport, userauth, connection
|
||||
|
||||
import random
|
||||
|
||||
|
||||
class SSHFactory(protocol.Factory):
|
||||
"""
|
||||
A Factory for SSH servers.
|
||||
"""
|
||||
protocol = transport.SSHServerTransport
|
||||
|
||||
services = {
|
||||
'ssh-userauth':userauth.SSHUserAuthServer,
|
||||
'ssh-connection':connection.SSHConnection
|
||||
}
|
||||
def startFactory(self):
|
||||
"""
|
||||
Check for public and private keys.
|
||||
"""
|
||||
if not hasattr(self,'publicKeys'):
|
||||
self.publicKeys = self.getPublicKeys()
|
||||
if not hasattr(self,'privateKeys'):
|
||||
self.privateKeys = self.getPrivateKeys()
|
||||
if not self.publicKeys or not self.privateKeys:
|
||||
raise error.ConchError('no host keys, failing')
|
||||
if not hasattr(self,'primes'):
|
||||
self.primes = self.getPrimes()
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
"""
|
||||
Create an instance of the server side of the SSH protocol.
|
||||
|
||||
@type addr: L{twisted.internet.interfaces.IAddress} provider
|
||||
@param addr: The address at which the server will listen.
|
||||
|
||||
@rtype: L{twisted.conch.ssh.SSHServerTransport}
|
||||
@return: The built transport.
|
||||
"""
|
||||
t = protocol.Factory.buildProtocol(self, addr)
|
||||
t.supportedPublicKeys = self.privateKeys.keys()
|
||||
if not self.primes:
|
||||
log.msg('disabling diffie-hellman-group-exchange because we '
|
||||
'cannot find moduli file')
|
||||
ske = t.supportedKeyExchanges[:]
|
||||
ske.remove('diffie-hellman-group-exchange-sha1')
|
||||
t.supportedKeyExchanges = ske
|
||||
return t
|
||||
|
||||
|
||||
def getPublicKeys(self):
|
||||
"""
|
||||
Called when the factory is started to get the public portions of the
|
||||
servers host keys. Returns a dictionary mapping SSH key types to
|
||||
public key strings.
|
||||
|
||||
@rtype: C{dict}
|
||||
"""
|
||||
raise NotImplementedError('getPublicKeys unimplemented')
|
||||
|
||||
|
||||
def getPrivateKeys(self):
|
||||
"""
|
||||
Called when the factory is started to get the private portions of the
|
||||
servers host keys. Returns a dictionary mapping SSH key types to
|
||||
C{Crypto.PublicKey.pubkey.pubkey} objects.
|
||||
|
||||
@rtype: C{dict}
|
||||
"""
|
||||
raise NotImplementedError('getPrivateKeys unimplemented')
|
||||
|
||||
|
||||
def getPrimes(self):
|
||||
"""
|
||||
Called when the factory is started to get Diffie-Hellman generators and
|
||||
primes to use. Returns a dictionary mapping number of bits to lists
|
||||
of tuple of (generator, prime).
|
||||
|
||||
@rtype: C{dict}
|
||||
"""
|
||||
|
||||
|
||||
def getDHPrime(self, bits):
|
||||
"""
|
||||
Return a tuple of (g, p) for a Diffe-Hellman process, with p being as
|
||||
close to bits bits as possible.
|
||||
|
||||
@type bits: C{int}
|
||||
@rtype: C{tuple}
|
||||
"""
|
||||
primesKeys = self.primes.keys()
|
||||
primesKeys.sort(lambda x, y: cmp(abs(x - bits), abs(y - bits)))
|
||||
realBits = primesKeys[0]
|
||||
return random.choice(self.primes[realBits])
|
||||
|
||||
|
||||
def getService(self, transport, service):
|
||||
"""
|
||||
Return a class to use as a service for the given transport.
|
||||
|
||||
@type transport: L{transport.SSHServerTransport}
|
||||
@type service: C{str}
|
||||
@rtype: subclass of L{service.SSHService}
|
||||
"""
|
||||
if service == 'ssh-userauth' or hasattr(transport, 'avatar'):
|
||||
return self.services[service]
|
@ -0,0 +1,933 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
|
||||
#
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import errno
|
||||
import struct
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.conch.interfaces import ISFTPServer, ISFTPFile
|
||||
from twisted.conch.ssh.common import NS, getNS
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.python import failure, log
|
||||
|
||||
|
||||
|
||||
|
||||
class FileTransferBase(protocol.Protocol):
|
||||
|
||||
versions = (3, )
|
||||
|
||||
packetTypes = {}
|
||||
|
||||
def __init__(self):
|
||||
self.buf = ''
|
||||
self.otherVersion = None # this gets set
|
||||
|
||||
def sendPacket(self, kind, data):
|
||||
self.transport.write(struct.pack('!LB', len(data)+1, kind) + data)
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buf += data
|
||||
while len(self.buf) > 5:
|
||||
length, kind = struct.unpack('!LB', self.buf[:5])
|
||||
if len(self.buf) < 4 + length:
|
||||
return
|
||||
data, self.buf = self.buf[5:4+length], self.buf[4+length:]
|
||||
packetType = self.packetTypes.get(kind, None)
|
||||
if not packetType:
|
||||
log.msg('no packet type for', kind)
|
||||
continue
|
||||
f = getattr(self, 'packet_%s' % packetType, None)
|
||||
if not f:
|
||||
log.msg('not implemented: %s' % packetType)
|
||||
log.msg(repr(data[4:]))
|
||||
reqId, = struct.unpack('!L', data[:4])
|
||||
self._sendStatus(reqId, FX_OP_UNSUPPORTED,
|
||||
"don't understand %s" % packetType)
|
||||
#XXX not implemented
|
||||
continue
|
||||
try:
|
||||
f(data)
|
||||
except Exception:
|
||||
log.err()
|
||||
continue
|
||||
|
||||
|
||||
def _parseAttributes(self, data):
|
||||
flags ,= struct.unpack('!L', data[:4])
|
||||
attrs = {}
|
||||
data = data[4:]
|
||||
if flags & FILEXFER_ATTR_SIZE == FILEXFER_ATTR_SIZE:
|
||||
size ,= struct.unpack('!Q', data[:8])
|
||||
attrs['size'] = size
|
||||
data = data[8:]
|
||||
if flags & FILEXFER_ATTR_OWNERGROUP == FILEXFER_ATTR_OWNERGROUP:
|
||||
uid, gid = struct.unpack('!2L', data[:8])
|
||||
attrs['uid'] = uid
|
||||
attrs['gid'] = gid
|
||||
data = data[8:]
|
||||
if flags & FILEXFER_ATTR_PERMISSIONS == FILEXFER_ATTR_PERMISSIONS:
|
||||
perms ,= struct.unpack('!L', data[:4])
|
||||
attrs['permissions'] = perms
|
||||
data = data[4:]
|
||||
if flags & FILEXFER_ATTR_ACMODTIME == FILEXFER_ATTR_ACMODTIME:
|
||||
atime, mtime = struct.unpack('!2L', data[:8])
|
||||
attrs['atime'] = atime
|
||||
attrs['mtime'] = mtime
|
||||
data = data[8:]
|
||||
if flags & FILEXFER_ATTR_EXTENDED == FILEXFER_ATTR_EXTENDED:
|
||||
extended_count ,= struct.unpack('!L', data[:4])
|
||||
data = data[4:]
|
||||
for i in xrange(extended_count):
|
||||
extended_type, data = getNS(data)
|
||||
extended_data, data = getNS(data)
|
||||
attrs['ext_%s' % extended_type] = extended_data
|
||||
return attrs, data
|
||||
|
||||
def _packAttributes(self, attrs):
|
||||
flags = 0
|
||||
data = ''
|
||||
if 'size' in attrs:
|
||||
data += struct.pack('!Q', attrs['size'])
|
||||
flags |= FILEXFER_ATTR_SIZE
|
||||
if 'uid' in attrs and 'gid' in attrs:
|
||||
data += struct.pack('!2L', attrs['uid'], attrs['gid'])
|
||||
flags |= FILEXFER_ATTR_OWNERGROUP
|
||||
if 'permissions' in attrs:
|
||||
data += struct.pack('!L', attrs['permissions'])
|
||||
flags |= FILEXFER_ATTR_PERMISSIONS
|
||||
if 'atime' in attrs and 'mtime' in attrs:
|
||||
data += struct.pack('!2L', attrs['atime'], attrs['mtime'])
|
||||
flags |= FILEXFER_ATTR_ACMODTIME
|
||||
extended = []
|
||||
for k in attrs:
|
||||
if k.startswith('ext_'):
|
||||
ext_type = NS(k[4:])
|
||||
ext_data = NS(attrs[k])
|
||||
extended.append(ext_type+ext_data)
|
||||
if extended:
|
||||
data += struct.pack('!L', len(extended))
|
||||
data += ''.join(extended)
|
||||
flags |= FILEXFER_ATTR_EXTENDED
|
||||
return struct.pack('!L', flags) + data
|
||||
|
||||
class FileTransferServer(FileTransferBase):
|
||||
|
||||
def __init__(self, data=None, avatar=None):
|
||||
FileTransferBase.__init__(self)
|
||||
self.client = ISFTPServer(avatar) # yay interfaces
|
||||
self.openFiles = {}
|
||||
self.openDirs = {}
|
||||
|
||||
def packet_INIT(self, data):
|
||||
version ,= struct.unpack('!L', data[:4])
|
||||
self.version = min(list(self.versions) + [version])
|
||||
data = data[4:]
|
||||
ext = {}
|
||||
while data:
|
||||
ext_name, data = getNS(data)
|
||||
ext_data, data = getNS(data)
|
||||
ext[ext_name] = ext_data
|
||||
our_ext = self.client.gotVersion(version, ext)
|
||||
our_ext_data = ""
|
||||
for (k,v) in our_ext.items():
|
||||
our_ext_data += NS(k) + NS(v)
|
||||
self.sendPacket(FXP_VERSION, struct.pack('!L', self.version) + \
|
||||
our_ext_data)
|
||||
|
||||
def packet_OPEN(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
filename, data = getNS(data)
|
||||
flags ,= struct.unpack('!L', data[:4])
|
||||
data = data[4:]
|
||||
attrs, data = self._parseAttributes(data)
|
||||
assert data == '', 'still have data in OPEN: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.openFile, filename, flags, attrs)
|
||||
d.addCallback(self._cbOpenFile, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, "open failed")
|
||||
|
||||
def _cbOpenFile(self, fileObj, requestId):
|
||||
fileId = str(hash(fileObj))
|
||||
if fileId in self.openFiles:
|
||||
raise KeyError, 'id already open'
|
||||
self.openFiles[fileId] = fileObj
|
||||
self.sendPacket(FXP_HANDLE, requestId + NS(fileId))
|
||||
|
||||
def packet_CLOSE(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
assert data == '', 'still have data in CLOSE: %s' % repr(data)
|
||||
if handle in self.openFiles:
|
||||
fileObj = self.openFiles[handle]
|
||||
d = defer.maybeDeferred(fileObj.close)
|
||||
d.addCallback(self._cbClose, handle, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, "close failed")
|
||||
elif handle in self.openDirs:
|
||||
dirObj = self.openDirs[handle][0]
|
||||
d = defer.maybeDeferred(dirObj.close)
|
||||
d.addCallback(self._cbClose, handle, requestId, 1)
|
||||
d.addErrback(self._ebStatus, requestId, "close failed")
|
||||
else:
|
||||
self._ebClose(failure.Failure(KeyError()), requestId)
|
||||
|
||||
def _cbClose(self, result, handle, requestId, isDir = 0):
|
||||
if isDir:
|
||||
del self.openDirs[handle]
|
||||
else:
|
||||
del self.openFiles[handle]
|
||||
self._sendStatus(requestId, FX_OK, 'file closed')
|
||||
|
||||
def packet_READ(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
(offset, length), data = struct.unpack('!QL', data[:12]), data[12:]
|
||||
assert data == '', 'still have data in READ: %s' % repr(data)
|
||||
if handle not in self.openFiles:
|
||||
self._ebRead(failure.Failure(KeyError()), requestId)
|
||||
else:
|
||||
fileObj = self.openFiles[handle]
|
||||
d = defer.maybeDeferred(fileObj.readChunk, offset, length)
|
||||
d.addCallback(self._cbRead, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, "read failed")
|
||||
|
||||
def _cbRead(self, result, requestId):
|
||||
if result == '': # python's read will return this for EOF
|
||||
raise EOFError()
|
||||
self.sendPacket(FXP_DATA, requestId + NS(result))
|
||||
|
||||
def packet_WRITE(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
offset, = struct.unpack('!Q', data[:8])
|
||||
data = data[8:]
|
||||
writeData, data = getNS(data)
|
||||
assert data == '', 'still have data in WRITE: %s' % repr(data)
|
||||
if handle not in self.openFiles:
|
||||
self._ebWrite(failure.Failure(KeyError()), requestId)
|
||||
else:
|
||||
fileObj = self.openFiles[handle]
|
||||
d = defer.maybeDeferred(fileObj.writeChunk, offset, writeData)
|
||||
d.addCallback(self._cbStatus, requestId, "write succeeded")
|
||||
d.addErrback(self._ebStatus, requestId, "write failed")
|
||||
|
||||
def packet_REMOVE(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
filename, data = getNS(data)
|
||||
assert data == '', 'still have data in REMOVE: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.removeFile, filename)
|
||||
d.addCallback(self._cbStatus, requestId, "remove succeeded")
|
||||
d.addErrback(self._ebStatus, requestId, "remove failed")
|
||||
|
||||
def packet_RENAME(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
oldPath, data = getNS(data)
|
||||
newPath, data = getNS(data)
|
||||
assert data == '', 'still have data in RENAME: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.renameFile, oldPath, newPath)
|
||||
d.addCallback(self._cbStatus, requestId, "rename succeeded")
|
||||
d.addErrback(self._ebStatus, requestId, "rename failed")
|
||||
|
||||
def packet_MKDIR(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
attrs, data = self._parseAttributes(data)
|
||||
assert data == '', 'still have data in MKDIR: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.makeDirectory, path, attrs)
|
||||
d.addCallback(self._cbStatus, requestId, "mkdir succeeded")
|
||||
d.addErrback(self._ebStatus, requestId, "mkdir failed")
|
||||
|
||||
def packet_RMDIR(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
assert data == '', 'still have data in RMDIR: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.removeDirectory, path)
|
||||
d.addCallback(self._cbStatus, requestId, "rmdir succeeded")
|
||||
d.addErrback(self._ebStatus, requestId, "rmdir failed")
|
||||
|
||||
def packet_OPENDIR(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
assert data == '', 'still have data in OPENDIR: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.openDirectory, path)
|
||||
d.addCallback(self._cbOpenDirectory, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, "opendir failed")
|
||||
|
||||
def _cbOpenDirectory(self, dirObj, requestId):
|
||||
handle = str(hash(dirObj))
|
||||
if handle in self.openDirs:
|
||||
raise KeyError, "already opened this directory"
|
||||
self.openDirs[handle] = [dirObj, iter(dirObj)]
|
||||
self.sendPacket(FXP_HANDLE, requestId + NS(handle))
|
||||
|
||||
def packet_READDIR(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
assert data == '', 'still have data in READDIR: %s' % repr(data)
|
||||
if handle not in self.openDirs:
|
||||
self._ebStatus(failure.Failure(KeyError()), requestId)
|
||||
else:
|
||||
dirObj, dirIter = self.openDirs[handle]
|
||||
d = defer.maybeDeferred(self._scanDirectory, dirIter, [])
|
||||
d.addCallback(self._cbSendDirectory, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, "scan directory failed")
|
||||
|
||||
def _scanDirectory(self, dirIter, f):
|
||||
while len(f) < 250:
|
||||
try:
|
||||
info = dirIter.next()
|
||||
except StopIteration:
|
||||
if not f:
|
||||
raise EOFError
|
||||
return f
|
||||
if isinstance(info, defer.Deferred):
|
||||
info.addCallback(self._cbScanDirectory, dirIter, f)
|
||||
return
|
||||
else:
|
||||
f.append(info)
|
||||
return f
|
||||
|
||||
def _cbScanDirectory(self, result, dirIter, f):
|
||||
f.append(result)
|
||||
return self._scanDirectory(dirIter, f)
|
||||
|
||||
def _cbSendDirectory(self, result, requestId):
|
||||
data = ''
|
||||
for (filename, longname, attrs) in result:
|
||||
data += NS(filename)
|
||||
data += NS(longname)
|
||||
data += self._packAttributes(attrs)
|
||||
self.sendPacket(FXP_NAME, requestId +
|
||||
struct.pack('!L', len(result))+data)
|
||||
|
||||
def packet_STAT(self, data, followLinks = 1):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
assert data == '', 'still have data in STAT/LSTAT: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.getAttrs, path, followLinks)
|
||||
d.addCallback(self._cbStat, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, 'stat/lstat failed')
|
||||
|
||||
def packet_LSTAT(self, data):
|
||||
self.packet_STAT(data, 0)
|
||||
|
||||
def packet_FSTAT(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
assert data == '', 'still have data in FSTAT: %s' % repr(data)
|
||||
if handle not in self.openFiles:
|
||||
self._ebStatus(failure.Failure(KeyError('%s not in self.openFiles'
|
||||
% handle)), requestId)
|
||||
else:
|
||||
fileObj = self.openFiles[handle]
|
||||
d = defer.maybeDeferred(fileObj.getAttrs)
|
||||
d.addCallback(self._cbStat, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, 'fstat failed')
|
||||
|
||||
def _cbStat(self, result, requestId):
|
||||
data = requestId + self._packAttributes(result)
|
||||
self.sendPacket(FXP_ATTRS, data)
|
||||
|
||||
def packet_SETSTAT(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
attrs, data = self._parseAttributes(data)
|
||||
if data != '':
|
||||
log.msg('WARN: still have data in SETSTAT: %s' % repr(data))
|
||||
d = defer.maybeDeferred(self.client.setAttrs, path, attrs)
|
||||
d.addCallback(self._cbStatus, requestId, 'setstat succeeded')
|
||||
d.addErrback(self._ebStatus, requestId, 'setstat failed')
|
||||
|
||||
def packet_FSETSTAT(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
handle, data = getNS(data)
|
||||
attrs, data = self._parseAttributes(data)
|
||||
assert data == '', 'still have data in FSETSTAT: %s' % repr(data)
|
||||
if handle not in self.openFiles:
|
||||
self._ebStatus(failure.Failure(KeyError()), requestId)
|
||||
else:
|
||||
fileObj = self.openFiles[handle]
|
||||
d = defer.maybeDeferred(fileObj.setAttrs, attrs)
|
||||
d.addCallback(self._cbStatus, requestId, 'fsetstat succeeded')
|
||||
d.addErrback(self._ebStatus, requestId, 'fsetstat failed')
|
||||
|
||||
def packet_READLINK(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
assert data == '', 'still have data in READLINK: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.readLink, path)
|
||||
d.addCallback(self._cbReadLink, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, 'readlink failed')
|
||||
|
||||
def _cbReadLink(self, result, requestId):
|
||||
self._cbSendDirectory([(result, '', {})], requestId)
|
||||
|
||||
def packet_SYMLINK(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
linkPath, data = getNS(data)
|
||||
targetPath, data = getNS(data)
|
||||
d = defer.maybeDeferred(self.client.makeLink, linkPath, targetPath)
|
||||
d.addCallback(self._cbStatus, requestId, 'symlink succeeded')
|
||||
d.addErrback(self._ebStatus, requestId, 'symlink failed')
|
||||
|
||||
def packet_REALPATH(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
path, data = getNS(data)
|
||||
assert data == '', 'still have data in REALPATH: %s' % repr(data)
|
||||
d = defer.maybeDeferred(self.client.realPath, path)
|
||||
d.addCallback(self._cbReadLink, requestId) # same return format
|
||||
d.addErrback(self._ebStatus, requestId, 'realpath failed')
|
||||
|
||||
def packet_EXTENDED(self, data):
|
||||
requestId = data[:4]
|
||||
data = data[4:]
|
||||
extName, extData = getNS(data)
|
||||
d = defer.maybeDeferred(self.client.extendedRequest, extName, extData)
|
||||
d.addCallback(self._cbExtended, requestId)
|
||||
d.addErrback(self._ebStatus, requestId, 'extended %s failed' % extName)
|
||||
|
||||
def _cbExtended(self, data, requestId):
|
||||
self.sendPacket(FXP_EXTENDED_REPLY, requestId + data)
|
||||
|
||||
def _cbStatus(self, result, requestId, msg = "request succeeded"):
|
||||
self._sendStatus(requestId, FX_OK, msg)
|
||||
|
||||
def _ebStatus(self, reason, requestId, msg = "request failed"):
|
||||
code = FX_FAILURE
|
||||
message = msg
|
||||
if reason.type in (IOError, OSError):
|
||||
if reason.value.errno == errno.ENOENT: # no such file
|
||||
code = FX_NO_SUCH_FILE
|
||||
message = reason.value.strerror
|
||||
elif reason.value.errno == errno.EACCES: # permission denied
|
||||
code = FX_PERMISSION_DENIED
|
||||
message = reason.value.strerror
|
||||
elif reason.value.errno == errno.EEXIST:
|
||||
code = FX_FILE_ALREADY_EXISTS
|
||||
else:
|
||||
log.err(reason)
|
||||
elif reason.type == EOFError: # EOF
|
||||
code = FX_EOF
|
||||
if reason.value.args:
|
||||
message = reason.value.args[0]
|
||||
elif reason.type == NotImplementedError:
|
||||
code = FX_OP_UNSUPPORTED
|
||||
if reason.value.args:
|
||||
message = reason.value.args[0]
|
||||
elif reason.type == SFTPError:
|
||||
code = reason.value.code
|
||||
message = reason.value.message
|
||||
else:
|
||||
log.err(reason)
|
||||
self._sendStatus(requestId, code, message)
|
||||
|
||||
def _sendStatus(self, requestId, code, message, lang = ''):
|
||||
"""
|
||||
Helper method to send a FXP_STATUS message.
|
||||
"""
|
||||
data = requestId + struct.pack('!L', code)
|
||||
data += NS(message)
|
||||
data += NS(lang)
|
||||
self.sendPacket(FXP_STATUS, data)
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Clean all opened files and directories.
|
||||
"""
|
||||
for fileObj in self.openFiles.values():
|
||||
fileObj.close()
|
||||
self.openFiles = {}
|
||||
for (dirObj, dirIter) in self.openDirs.values():
|
||||
dirObj.close()
|
||||
self.openDirs = {}
|
||||
|
||||
|
||||
|
||||
class FileTransferClient(FileTransferBase):
|
||||
|
||||
def __init__(self, extData = {}):
|
||||
"""
|
||||
@param extData: a dict of extended_name : extended_data items
|
||||
to be sent to the server.
|
||||
"""
|
||||
FileTransferBase.__init__(self)
|
||||
self.extData = {}
|
||||
self.counter = 0
|
||||
self.openRequests = {} # id -> Deferred
|
||||
self.wasAFile = {} # Deferred -> 1 TERRIBLE HACK
|
||||
|
||||
def connectionMade(self):
|
||||
data = struct.pack('!L', max(self.versions))
|
||||
for k,v in self.extData.itervalues():
|
||||
data += NS(k) + NS(v)
|
||||
self.sendPacket(FXP_INIT, data)
|
||||
|
||||
def _sendRequest(self, msg, data):
|
||||
data = struct.pack('!L', self.counter) + data
|
||||
d = defer.Deferred()
|
||||
self.openRequests[self.counter] = d
|
||||
self.counter += 1
|
||||
self.sendPacket(msg, data)
|
||||
return d
|
||||
|
||||
def _parseRequest(self, data):
|
||||
(id,) = struct.unpack('!L', data[:4])
|
||||
d = self.openRequests[id]
|
||||
del self.openRequests[id]
|
||||
return d, data[4:]
|
||||
|
||||
def openFile(self, filename, flags, attrs):
|
||||
"""
|
||||
Open a file.
|
||||
|
||||
This method returns a L{Deferred} that is called back with an object
|
||||
that provides the L{ISFTPFile} interface.
|
||||
|
||||
@param filename: a string representing the file to open.
|
||||
|
||||
@param flags: a integer of the flags to open the file with, ORed together.
|
||||
The flags and their values are listed at the bottom of this file.
|
||||
|
||||
@param attrs: a list of attributes to open the file with. It is a
|
||||
dictionary, consisting of 0 or more keys. The possible keys are::
|
||||
|
||||
size: the size of the file in bytes
|
||||
uid: the user ID of the file as an integer
|
||||
gid: the group ID of the file as an integer
|
||||
permissions: the permissions of the file with as an integer.
|
||||
the bit representation of this field is defined by POSIX.
|
||||
atime: the access time of the file as seconds since the epoch.
|
||||
mtime: the modification time of the file as seconds since the epoch.
|
||||
ext_*: extended attributes. The server is not required to
|
||||
understand this, but it may.
|
||||
|
||||
NOTE: there is no way to indicate text or binary files. it is up
|
||||
to the SFTP client to deal with this.
|
||||
"""
|
||||
data = NS(filename) + struct.pack('!L', flags) + self._packAttributes(attrs)
|
||||
d = self._sendRequest(FXP_OPEN, data)
|
||||
self.wasAFile[d] = (1, filename) # HACK
|
||||
return d
|
||||
|
||||
def removeFile(self, filename):
|
||||
"""
|
||||
Remove the given file.
|
||||
|
||||
This method returns a Deferred that is called back when it succeeds.
|
||||
|
||||
@param filename: the name of the file as a string.
|
||||
"""
|
||||
return self._sendRequest(FXP_REMOVE, NS(filename))
|
||||
|
||||
def renameFile(self, oldpath, newpath):
|
||||
"""
|
||||
Rename the given file.
|
||||
|
||||
This method returns a Deferred that is called back when it succeeds.
|
||||
|
||||
@param oldpath: the current location of the file.
|
||||
@param newpath: the new file name.
|
||||
"""
|
||||
return self._sendRequest(FXP_RENAME, NS(oldpath)+NS(newpath))
|
||||
|
||||
def makeDirectory(self, path, attrs):
|
||||
"""
|
||||
Make a directory.
|
||||
|
||||
This method returns a Deferred that is called back when it is
|
||||
created.
|
||||
|
||||
@param path: the name of the directory to create as a string.
|
||||
|
||||
@param attrs: a dictionary of attributes to create the directory
|
||||
with. Its meaning is the same as the attrs in the openFile method.
|
||||
"""
|
||||
return self._sendRequest(FXP_MKDIR, NS(path)+self._packAttributes(attrs))
|
||||
|
||||
def removeDirectory(self, path):
|
||||
"""
|
||||
Remove a directory (non-recursively)
|
||||
|
||||
It is an error to remove a directory that has files or directories in
|
||||
it.
|
||||
|
||||
This method returns a Deferred that is called back when it is removed.
|
||||
|
||||
@param path: the directory to remove.
|
||||
"""
|
||||
return self._sendRequest(FXP_RMDIR, NS(path))
|
||||
|
||||
def openDirectory(self, path):
|
||||
"""
|
||||
Open a directory for scanning.
|
||||
|
||||
This method returns a Deferred that is called back with an iterable
|
||||
object that has a close() method.
|
||||
|
||||
The close() method is called when the client is finished reading
|
||||
from the directory. At this point, the iterable will no longer
|
||||
be used.
|
||||
|
||||
The iterable returns triples of the form (filename, longname, attrs)
|
||||
or a Deferred that returns the same. The sequence must support
|
||||
__getitem__, but otherwise may be any 'sequence-like' object.
|
||||
|
||||
filename is the name of the file relative to the directory.
|
||||
logname is an expanded format of the filename. The recommended format
|
||||
is:
|
||||
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
|
||||
1234567890 123 12345678 12345678 12345678 123456789012
|
||||
|
||||
The first line is sample output, the second is the length of the field.
|
||||
The fields are: permissions, link count, user owner, group owner,
|
||||
size in bytes, modification time.
|
||||
|
||||
attrs is a dictionary in the format of the attrs argument to openFile.
|
||||
|
||||
@param path: the directory to open.
|
||||
"""
|
||||
d = self._sendRequest(FXP_OPENDIR, NS(path))
|
||||
self.wasAFile[d] = (0, path)
|
||||
return d
|
||||
|
||||
def getAttrs(self, path, followLinks=0):
|
||||
"""
|
||||
Return the attributes for the given path.
|
||||
|
||||
This method returns a dictionary in the same format as the attrs
|
||||
argument to openFile or a Deferred that is called back with same.
|
||||
|
||||
@param path: the path to return attributes for as a string.
|
||||
@param followLinks: a boolean. if it is True, follow symbolic links
|
||||
and return attributes for the real path at the base. if it is False,
|
||||
return attributes for the specified path.
|
||||
"""
|
||||
if followLinks: m = FXP_STAT
|
||||
else: m = FXP_LSTAT
|
||||
return self._sendRequest(m, NS(path))
|
||||
|
||||
def setAttrs(self, path, attrs):
|
||||
"""
|
||||
Set the attributes for the path.
|
||||
|
||||
This method returns when the attributes are set or a Deferred that is
|
||||
called back when they are.
|
||||
|
||||
@param path: the path to set attributes for as a string.
|
||||
@param attrs: a dictionary in the same format as the attrs argument to
|
||||
openFile.
|
||||
"""
|
||||
data = NS(path) + self._packAttributes(attrs)
|
||||
return self._sendRequest(FXP_SETSTAT, data)
|
||||
|
||||
def readLink(self, path):
|
||||
"""
|
||||
Find the root of a set of symbolic links.
|
||||
|
||||
This method returns the target of the link, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param path: the path of the symlink to read.
|
||||
"""
|
||||
d = self._sendRequest(FXP_READLINK, NS(path))
|
||||
return d.addCallback(self._cbRealPath)
|
||||
|
||||
def makeLink(self, linkPath, targetPath):
|
||||
"""
|
||||
Create a symbolic link.
|
||||
|
||||
This method returns when the link is made, or a Deferred that
|
||||
returns the same.
|
||||
|
||||
@param linkPath: the pathname of the symlink as a string
|
||||
@param targetPath: the path of the target of the link as a string.
|
||||
"""
|
||||
return self._sendRequest(FXP_SYMLINK, NS(linkPath)+NS(targetPath))
|
||||
|
||||
def realPath(self, path):
|
||||
"""
|
||||
Convert any path to an absolute path.
|
||||
|
||||
This method returns the absolute path as a string, or a Deferred
|
||||
that returns the same.
|
||||
|
||||
@param path: the path to convert as a string.
|
||||
"""
|
||||
d = self._sendRequest(FXP_REALPATH, NS(path))
|
||||
return d.addCallback(self._cbRealPath)
|
||||
|
||||
def _cbRealPath(self, result):
|
||||
name, longname, attrs = result[0]
|
||||
return name
|
||||
|
||||
def extendedRequest(self, request, data):
|
||||
"""
|
||||
Make an extended request of the server.
|
||||
|
||||
The method returns a Deferred that is called back with
|
||||
the result of the extended request.
|
||||
|
||||
@param request: the name of the extended request to make.
|
||||
@param data: any other data that goes along with the request.
|
||||
"""
|
||||
return self._sendRequest(FXP_EXTENDED, NS(request) + data)
|
||||
|
||||
def packet_VERSION(self, data):
|
||||
version, = struct.unpack('!L', data[:4])
|
||||
data = data[4:]
|
||||
d = {}
|
||||
while data:
|
||||
k, data = getNS(data)
|
||||
v, data = getNS(data)
|
||||
d[k]=v
|
||||
self.version = version
|
||||
self.gotServerVersion(version, d)
|
||||
|
||||
def packet_STATUS(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
code, = struct.unpack('!L', data[:4])
|
||||
data = data[4:]
|
||||
if len(data) >= 4:
|
||||
msg, data = getNS(data)
|
||||
if len(data) >= 4:
|
||||
lang, data = getNS(data)
|
||||
else:
|
||||
lang = ''
|
||||
else:
|
||||
msg = ''
|
||||
lang = ''
|
||||
if code == FX_OK:
|
||||
d.callback((msg, lang))
|
||||
elif code == FX_EOF:
|
||||
d.errback(EOFError(msg))
|
||||
elif code == FX_OP_UNSUPPORTED:
|
||||
d.errback(NotImplementedError(msg))
|
||||
else:
|
||||
d.errback(SFTPError(code, msg, lang))
|
||||
|
||||
def packet_HANDLE(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
isFile, name = self.wasAFile.pop(d)
|
||||
if isFile:
|
||||
cb = ClientFile(self, getNS(data)[0])
|
||||
else:
|
||||
cb = ClientDirectory(self, getNS(data)[0])
|
||||
cb.name = name
|
||||
d.callback(cb)
|
||||
|
||||
def packet_DATA(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
d.callback(getNS(data)[0])
|
||||
|
||||
def packet_NAME(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
count, = struct.unpack('!L', data[:4])
|
||||
data = data[4:]
|
||||
files = []
|
||||
for i in range(count):
|
||||
filename, data = getNS(data)
|
||||
longname, data = getNS(data)
|
||||
attrs, data = self._parseAttributes(data)
|
||||
files.append((filename, longname, attrs))
|
||||
d.callback(files)
|
||||
|
||||
def packet_ATTRS(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
d.callback(self._parseAttributes(data)[0])
|
||||
|
||||
def packet_EXTENDED_REPLY(self, data):
|
||||
d, data = self._parseRequest(data)
|
||||
d.callback(data)
|
||||
|
||||
def gotServerVersion(self, serverVersion, extData):
|
||||
"""
|
||||
Called when the client sends their version info.
|
||||
|
||||
@param otherVersion: an integer representing the version of the SFTP
|
||||
protocol they are claiming.
|
||||
@param extData: a dictionary of extended_name : extended_data items.
|
||||
These items are sent by the client to indicate additional features.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@implementer(ISFTPFile)
|
||||
class ClientFile:
|
||||
def __init__(self, parent, handle):
|
||||
self.parent = parent
|
||||
self.handle = NS(handle)
|
||||
|
||||
def close(self):
|
||||
return self.parent._sendRequest(FXP_CLOSE, self.handle)
|
||||
|
||||
def readChunk(self, offset, length):
|
||||
data = self.handle + struct.pack("!QL", offset, length)
|
||||
return self.parent._sendRequest(FXP_READ, data)
|
||||
|
||||
def writeChunk(self, offset, chunk):
|
||||
data = self.handle + struct.pack("!Q", offset) + NS(chunk)
|
||||
return self.parent._sendRequest(FXP_WRITE, data)
|
||||
|
||||
def getAttrs(self):
|
||||
return self.parent._sendRequest(FXP_FSTAT, self.handle)
|
||||
|
||||
def setAttrs(self, attrs):
|
||||
data = self.handle + self.parent._packAttributes(attrs)
|
||||
return self.parent._sendRequest(FXP_FSTAT, data)
|
||||
|
||||
class ClientDirectory:
|
||||
|
||||
def __init__(self, parent, handle):
|
||||
self.parent = parent
|
||||
self.handle = NS(handle)
|
||||
self.filesCache = []
|
||||
|
||||
def read(self):
|
||||
d = self.parent._sendRequest(FXP_READDIR, self.handle)
|
||||
return d
|
||||
|
||||
def close(self):
|
||||
return self.parent._sendRequest(FXP_CLOSE, self.handle)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if self.filesCache:
|
||||
return self.filesCache.pop(0)
|
||||
d = self.read()
|
||||
d.addCallback(self._cbReadDir)
|
||||
d.addErrback(self._ebReadDir)
|
||||
return d
|
||||
|
||||
def _cbReadDir(self, names):
|
||||
self.filesCache = names[1:]
|
||||
return names[0]
|
||||
|
||||
def _ebReadDir(self, reason):
|
||||
reason.trap(EOFError)
|
||||
def _():
|
||||
raise StopIteration
|
||||
self.next = _
|
||||
return reason
|
||||
|
||||
|
||||
class SFTPError(Exception):
|
||||
|
||||
def __init__(self, errorCode, errorMessage, lang = ''):
|
||||
Exception.__init__(self)
|
||||
self.code = errorCode
|
||||
self._message = errorMessage
|
||||
self.lang = lang
|
||||
|
||||
|
||||
def message(self):
|
||||
"""
|
||||
A string received over the network that explains the error to a human.
|
||||
"""
|
||||
# Python 2.6 deprecates assigning to the 'message' attribute of an
|
||||
# exception. We define this read-only property here in order to
|
||||
# prevent the warning about deprecation while maintaining backwards
|
||||
# compatibility with object clients that rely on the 'message'
|
||||
# attribute being set correctly. See bug #3897.
|
||||
return self._message
|
||||
message = property(message)
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return 'SFTPError %s: %s' % (self.code, self.message)
|
||||
|
||||
FXP_INIT = 1
|
||||
FXP_VERSION = 2
|
||||
FXP_OPEN = 3
|
||||
FXP_CLOSE = 4
|
||||
FXP_READ = 5
|
||||
FXP_WRITE = 6
|
||||
FXP_LSTAT = 7
|
||||
FXP_FSTAT = 8
|
||||
FXP_SETSTAT = 9
|
||||
FXP_FSETSTAT = 10
|
||||
FXP_OPENDIR = 11
|
||||
FXP_READDIR = 12
|
||||
FXP_REMOVE = 13
|
||||
FXP_MKDIR = 14
|
||||
FXP_RMDIR = 15
|
||||
FXP_REALPATH = 16
|
||||
FXP_STAT = 17
|
||||
FXP_RENAME = 18
|
||||
FXP_READLINK = 19
|
||||
FXP_SYMLINK = 20
|
||||
FXP_STATUS = 101
|
||||
FXP_HANDLE = 102
|
||||
FXP_DATA = 103
|
||||
FXP_NAME = 104
|
||||
FXP_ATTRS = 105
|
||||
FXP_EXTENDED = 200
|
||||
FXP_EXTENDED_REPLY = 201
|
||||
|
||||
FILEXFER_ATTR_SIZE = 0x00000001
|
||||
FILEXFER_ATTR_UIDGID = 0x00000002
|
||||
FILEXFER_ATTR_OWNERGROUP = FILEXFER_ATTR_UIDGID
|
||||
FILEXFER_ATTR_PERMISSIONS = 0x00000004
|
||||
FILEXFER_ATTR_ACMODTIME = 0x00000008
|
||||
FILEXFER_ATTR_EXTENDED = 0x80000000L
|
||||
|
||||
FILEXFER_TYPE_REGULAR = 1
|
||||
FILEXFER_TYPE_DIRECTORY = 2
|
||||
FILEXFER_TYPE_SYMLINK = 3
|
||||
FILEXFER_TYPE_SPECIAL = 4
|
||||
FILEXFER_TYPE_UNKNOWN = 5
|
||||
|
||||
FXF_READ = 0x00000001
|
||||
FXF_WRITE = 0x00000002
|
||||
FXF_APPEND = 0x00000004
|
||||
FXF_CREAT = 0x00000008
|
||||
FXF_TRUNC = 0x00000010
|
||||
FXF_EXCL = 0x00000020
|
||||
FXF_TEXT = 0x00000040
|
||||
|
||||
FX_OK = 0
|
||||
FX_EOF = 1
|
||||
FX_NO_SUCH_FILE = 2
|
||||
FX_PERMISSION_DENIED = 3
|
||||
FX_FAILURE = 4
|
||||
FX_BAD_MESSAGE = 5
|
||||
FX_NO_CONNECTION = 6
|
||||
FX_CONNECTION_LOST = 7
|
||||
FX_OP_UNSUPPORTED = 8
|
||||
FX_FILE_ALREADY_EXISTS = 11
|
||||
# http://tools.ietf.org/wg/secsh/draft-ietf-secsh-filexfer/ defines more
|
||||
# useful error codes, but so far OpenSSH doesn't implement them. We use them
|
||||
# internally for clarity, but for now define them all as FX_FAILURE to be
|
||||
# compatible with existing software.
|
||||
FX_NOT_A_DIRECTORY = FX_FAILURE
|
||||
FX_FILE_IS_A_DIRECTORY = FX_FAILURE
|
||||
|
||||
|
||||
# initialize FileTransferBase.packetTypes:
|
||||
g = globals()
|
||||
for name in g.keys():
|
||||
if name.startswith('FXP_'):
|
||||
value = g[name]
|
||||
FileTransferBase.packetTypes[value] = name[4:]
|
||||
del g, name, value
|
@ -0,0 +1,236 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of the TCP forwarding, which allows
|
||||
clients and servers to forward arbitrary TCP data across the connection.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.internet import protocol, reactor
|
||||
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
|
||||
from twisted.python import log
|
||||
|
||||
import common, channel
|
||||
|
||||
class SSHListenForwardingFactory(protocol.Factory):
|
||||
def __init__(self, connection, hostport, klass):
|
||||
self.conn = connection
|
||||
self.hostport = hostport # tuple
|
||||
self.klass = klass
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
channel = self.klass(conn = self.conn)
|
||||
client = SSHForwardingClient(channel)
|
||||
channel.client = client
|
||||
addrTuple = (addr.host, addr.port)
|
||||
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
|
||||
self.conn.openChannel(channel, channelOpenData)
|
||||
return client
|
||||
|
||||
class SSHListenForwardingChannel(channel.SSHChannel):
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
log.msg('opened forwarding channel %s' % self.id)
|
||||
if len(self.client.buf)>1:
|
||||
b = self.client.buf[1:]
|
||||
self.write(b)
|
||||
self.client.buf = ''
|
||||
|
||||
def openFailed(self, reason):
|
||||
self.closed()
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.client.transport.write(data)
|
||||
|
||||
def eofReceived(self):
|
||||
self.client.transport.loseConnection()
|
||||
|
||||
def closed(self):
|
||||
if hasattr(self, 'client'):
|
||||
log.msg('closing local forwarding channel %s' % self.id)
|
||||
self.client.transport.loseConnection()
|
||||
del self.client
|
||||
|
||||
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
|
||||
|
||||
name = 'direct-tcpip'
|
||||
|
||||
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
|
||||
|
||||
name = 'forwarded-tcpip'
|
||||
|
||||
|
||||
|
||||
class SSHConnectForwardingChannel(channel.SSHChannel):
|
||||
"""
|
||||
Channel used for handling server side forwarding request.
|
||||
It acts as a client for the remote forwarding destination.
|
||||
|
||||
@ivar hostport: C{(host, port)} requested by client as forwarding
|
||||
destination.
|
||||
@type hostport: C{tupple} or a C{sequence}
|
||||
|
||||
@ivar client: Protocol connected to the forwarding destination.
|
||||
@type client: L{protocol.Protocol}
|
||||
|
||||
@ivar clientBuf: Data received while forwarding channel is not yet
|
||||
connected.
|
||||
@type clientBuf: C{bytes}
|
||||
|
||||
@var _reactor: Reactor used for TCP connections.
|
||||
@type _reactor: A reactor.
|
||||
|
||||
@ivar _channelOpenDeferred: Deferred used in testing to check the
|
||||
result of C{channelOpen}.
|
||||
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
|
||||
"""
|
||||
_reactor = reactor
|
||||
|
||||
def __init__(self, hostport, *args, **kw):
|
||||
channel.SSHChannel.__init__(self, *args, **kw)
|
||||
self.hostport = hostport
|
||||
self.client = None
|
||||
self.clientBuf = ''
|
||||
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
log.msg("connecting to %s:%i" % self.hostport)
|
||||
ep = HostnameEndpoint(
|
||||
self._reactor, self.hostport[0], self.hostport[1])
|
||||
d = connectProtocol(ep, SSHForwardingClient(self))
|
||||
d.addCallbacks(self._setClient, self._close)
|
||||
self._channelOpenDeferred = d
|
||||
|
||||
def _setClient(self, client):
|
||||
"""
|
||||
Called when the connection was established to the forwarding
|
||||
destination.
|
||||
|
||||
@param client: Client protocol connected to the forwarding destination.
|
||||
@type client: L{protocol.Protocol}
|
||||
"""
|
||||
self.client = client
|
||||
log.msg("connected to %s:%i" % self.hostport)
|
||||
if self.clientBuf:
|
||||
self.client.transport.write(self.clientBuf)
|
||||
self.clientBuf = None
|
||||
if self.client.buf[1:]:
|
||||
self.write(self.client.buf[1:])
|
||||
self.client.buf = ''
|
||||
|
||||
|
||||
def _close(self, reason):
|
||||
"""
|
||||
Called when failed to connect to the forwarding destination.
|
||||
|
||||
@param reason: Reason why connection failed.
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
log.msg("failed to connect: %s" % reason)
|
||||
self.loseConnection()
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
if self.client:
|
||||
self.client.transport.write(data)
|
||||
else:
|
||||
self.clientBuf += data
|
||||
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
See: L{channel.SSHChannel}
|
||||
"""
|
||||
if self.client:
|
||||
log.msg('closed remote forwarding channel %s' % self.id)
|
||||
if self.client.channel:
|
||||
self.loseConnection()
|
||||
self.client.transport.loseConnection()
|
||||
del self.client
|
||||
|
||||
|
||||
|
||||
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
|
||||
remoteHP, origHP = unpackOpen_direct_tcpip(data)
|
||||
return SSHConnectForwardingChannel(remoteHP,
|
||||
remoteWindow=remoteWindow,
|
||||
remoteMaxPacket=remoteMaxPacket,
|
||||
avatar=avatar)
|
||||
|
||||
class SSHForwardingClient(protocol.Protocol):
|
||||
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
self.buf = '\000'
|
||||
|
||||
def dataReceived(self, data):
|
||||
if self.buf:
|
||||
self.buf += data
|
||||
else:
|
||||
self.channel.write(data)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.channel:
|
||||
self.channel.loseConnection()
|
||||
self.channel = None
|
||||
|
||||
|
||||
def packOpen_direct_tcpip((connHost, connPort), (origHost, origPort)):
|
||||
"""Pack the data suitable for sending in a CHANNEL_OPEN packet.
|
||||
"""
|
||||
conn = common.NS(connHost) + struct.pack('>L', connPort)
|
||||
orig = common.NS(origHost) + struct.pack('>L', origPort)
|
||||
return conn + orig
|
||||
|
||||
packOpen_forwarded_tcpip = packOpen_direct_tcpip
|
||||
|
||||
def unpackOpen_direct_tcpip(data):
|
||||
"""Unpack the data to a usable format.
|
||||
"""
|
||||
connHost, rest = common.getNS(data)
|
||||
connPort = int(struct.unpack('>L', rest[:4])[0])
|
||||
origHost, rest = common.getNS(rest[4:])
|
||||
origPort = int(struct.unpack('>L', rest[:4])[0])
|
||||
return (connHost, connPort), (origHost, origPort)
|
||||
|
||||
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
|
||||
|
||||
def packGlobal_tcpip_forward((host, port)):
|
||||
return common.NS(host) + struct.pack('>L', port)
|
||||
|
||||
def unpackGlobal_tcpip_forward(data):
|
||||
host, rest = common.getNS(data)
|
||||
port = int(struct.unpack('>L', rest[:4])[0])
|
||||
return host, port
|
||||
|
||||
"""This is how the data -> eof -> close stuff /should/ work.
|
||||
|
||||
debug3: channel 1: waiting for connection
|
||||
debug1: channel 1: connected
|
||||
debug1: channel 1: read<=0 rfd 7 len 0
|
||||
debug1: channel 1: read failed
|
||||
debug1: channel 1: close_read
|
||||
debug1: channel 1: input open -> drain
|
||||
debug1: channel 1: ibuf empty
|
||||
debug1: channel 1: send eof
|
||||
debug1: channel 1: input drain -> closed
|
||||
debug1: channel 1: rcvd eof
|
||||
debug1: channel 1: output open -> drain
|
||||
debug1: channel 1: obuf empty
|
||||
debug1: channel 1: close_write
|
||||
debug1: channel 1: output drain -> closed
|
||||
debug1: channel 1: rcvd close
|
||||
debug3: channel 1: will not send data after close
|
||||
debug1: channel 1: send close
|
||||
debug1: channel 1: is dead
|
||||
"""
|
@ -0,0 +1,857 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_keys -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Handling of RSA and DSA keys.
|
||||
|
||||
Maintainer: U{Paul Swartz}
|
||||
"""
|
||||
|
||||
# base library imports
|
||||
import base64
|
||||
import itertools
|
||||
from hashlib import md5, sha1
|
||||
|
||||
# external library imports
|
||||
from Crypto.Cipher import DES3, AES
|
||||
from Crypto.PublicKey import RSA, DSA
|
||||
from Crypto import Util
|
||||
from pyasn1.error import PyAsn1Error
|
||||
from pyasn1.type import univ
|
||||
from pyasn1.codec.ber import decoder as berDecoder
|
||||
from pyasn1.codec.ber import encoder as berEncoder
|
||||
|
||||
# twisted
|
||||
from twisted.python import randbytes
|
||||
|
||||
# sibling imports
|
||||
from twisted.conch.ssh import common, sexpy
|
||||
|
||||
|
||||
|
||||
class BadKeyError(Exception):
|
||||
"""
|
||||
Raised when a key isn't what we expected from it.
|
||||
|
||||
XXX: we really need to check for bad keys
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class EncryptedKeyError(Exception):
|
||||
"""
|
||||
Raised when an encrypted key is presented to fromString/fromFile without
|
||||
a password.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class Key(object):
|
||||
"""
|
||||
An object representing a key. A key can be either a public or
|
||||
private key. A public key can verify a signature; a private key can
|
||||
create or verify a signature. To generate a string that can be stored
|
||||
on disk, use the toString method. If you have a private key, but want
|
||||
the string representation of the public key, use Key.public().toString().
|
||||
|
||||
@ivar keyObject: The C{Crypto.PublicKey.pubkey.pubkey} object that
|
||||
operations are performed with.
|
||||
"""
|
||||
|
||||
def fromFile(Class, filename, type=None, passphrase=None):
|
||||
"""
|
||||
Return a Key object corresponding to the data in filename. type
|
||||
and passphrase function as they do in fromString.
|
||||
"""
|
||||
return Class.fromString(file(filename, 'rb').read(), type, passphrase)
|
||||
fromFile = classmethod(fromFile)
|
||||
|
||||
|
||||
def fromString(Class, data, type=None, passphrase=None):
|
||||
"""
|
||||
Return a Key object corresponding to the string data.
|
||||
type is optionally the type of string, matching a _fromString_*
|
||||
method. Otherwise, the _guessStringType() classmethod will be used
|
||||
to guess a type. If the key is encrypted, passphrase is used as
|
||||
the decryption key.
|
||||
|
||||
@type data: C{str}
|
||||
@type type: C{None}/C{str}
|
||||
@type passphrase: C{None}/C{str}
|
||||
@rtype: C{Key}
|
||||
"""
|
||||
if type is None:
|
||||
type = Class._guessStringType(data)
|
||||
if type is None:
|
||||
raise BadKeyError('cannot guess the type of %r' % data)
|
||||
method = getattr(Class, '_fromString_%s' % type.upper(), None)
|
||||
if method is None:
|
||||
raise BadKeyError('no _fromString method for %s' % type)
|
||||
if method.func_code.co_argcount == 2: # no passphrase
|
||||
if passphrase:
|
||||
raise BadKeyError('key not encrypted')
|
||||
return method(data)
|
||||
else:
|
||||
return method(data, passphrase)
|
||||
fromString = classmethod(fromString)
|
||||
|
||||
|
||||
def _fromString_BLOB(Class, blob):
|
||||
"""
|
||||
Return a public key object corresponding to this public key blob.
|
||||
The format of a RSA public key blob is::
|
||||
string 'ssh-rsa'
|
||||
integer e
|
||||
integer n
|
||||
|
||||
The format of a DSA public key blob is::
|
||||
string 'ssh-dss'
|
||||
integer p
|
||||
integer q
|
||||
integer g
|
||||
integer y
|
||||
|
||||
@type blob: C{str}
|
||||
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the key type (the first string) is unknown.
|
||||
"""
|
||||
keyType, rest = common.getNS(blob)
|
||||
if keyType == 'ssh-rsa':
|
||||
e, n, rest = common.getMP(rest, 2)
|
||||
return Class(RSA.construct((n, e)))
|
||||
elif keyType == 'ssh-dss':
|
||||
p, q, g, y, rest = common.getMP(rest, 4)
|
||||
return Class(DSA.construct((y, g, p, q)))
|
||||
else:
|
||||
raise BadKeyError('unknown blob type: %s' % keyType)
|
||||
_fromString_BLOB = classmethod(_fromString_BLOB)
|
||||
|
||||
|
||||
def _fromString_PRIVATE_BLOB(Class, blob):
|
||||
"""
|
||||
Return a private key object corresponding to this private key blob.
|
||||
The blob formats are as follows:
|
||||
|
||||
RSA keys::
|
||||
string 'ssh-rsa'
|
||||
integer n
|
||||
integer e
|
||||
integer d
|
||||
integer u
|
||||
integer p
|
||||
integer q
|
||||
|
||||
DSA keys::
|
||||
string 'ssh-dss'
|
||||
integer p
|
||||
integer q
|
||||
integer g
|
||||
integer y
|
||||
integer x
|
||||
|
||||
@type blob: C{str}
|
||||
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the key type (the first string) is unknown.
|
||||
"""
|
||||
keyType, rest = common.getNS(blob)
|
||||
|
||||
if keyType == 'ssh-rsa':
|
||||
n, e, d, u, p, q, rest = common.getMP(rest, 6)
|
||||
rsakey = Class(RSA.construct((n, e, d, p, q, u)))
|
||||
return rsakey
|
||||
elif keyType == 'ssh-dss':
|
||||
p, q, g, y, x, rest = common.getMP(rest, 5)
|
||||
dsakey = Class(DSA.construct((y, g, p, q, x)))
|
||||
return dsakey
|
||||
else:
|
||||
raise BadKeyError('unknown blob type: %s' % keyType)
|
||||
_fromString_PRIVATE_BLOB = classmethod(_fromString_PRIVATE_BLOB)
|
||||
|
||||
|
||||
def _fromString_PUBLIC_OPENSSH(Class, data):
|
||||
"""
|
||||
Return a public key object corresponding to this OpenSSH public key
|
||||
string. The format of an OpenSSH public key string is::
|
||||
<key type> <base64-encoded public key blob>
|
||||
|
||||
@type data: C{str}
|
||||
@return: A {Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the blob type is unknown.
|
||||
"""
|
||||
blob = base64.decodestring(data.split()[1])
|
||||
return Class._fromString_BLOB(blob)
|
||||
_fromString_PUBLIC_OPENSSH = classmethod(_fromString_PUBLIC_OPENSSH)
|
||||
|
||||
|
||||
def _fromString_PRIVATE_OPENSSH(Class, data, passphrase):
|
||||
"""
|
||||
Return a private key object corresponding to this OpenSSH private key
|
||||
string. If the key is encrypted, passphrase MUST be provided.
|
||||
Providing a passphrase for an unencrypted key is an error.
|
||||
|
||||
The format of an OpenSSH private key string is::
|
||||
-----BEGIN <key type> PRIVATE KEY-----
|
||||
[Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,<initialization value>]
|
||||
<base64-encoded ASN.1 structure>
|
||||
------END <key type> PRIVATE KEY------
|
||||
|
||||
The ASN.1 structure of a RSA key is::
|
||||
(0, n, e, d, p, q)
|
||||
|
||||
The ASN.1 structure of a DSA key is::
|
||||
(0, p, q, g, y, x)
|
||||
|
||||
@type data: C{str}
|
||||
@type passphrase: C{str}
|
||||
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if
|
||||
* a passphrase is provided for an unencrypted key
|
||||
* the ASN.1 encoding is incorrect
|
||||
@raises EncryptedKeyError: if
|
||||
* a passphrase is not provided for an encrypted key
|
||||
"""
|
||||
lines = data.strip().split('\n')
|
||||
kind = lines[0][11:14]
|
||||
if lines[1].startswith('Proc-Type: 4,ENCRYPTED'): # encrypted key
|
||||
if not passphrase:
|
||||
raise EncryptedKeyError('Passphrase must be provided '
|
||||
'for an encrypted key')
|
||||
|
||||
# Determine cipher and initialization vector
|
||||
try:
|
||||
_, cipher_iv_info = lines[2].split(' ', 1)
|
||||
cipher, ivdata = cipher_iv_info.rstrip().split(',', 1)
|
||||
except ValueError:
|
||||
raise BadKeyError('invalid DEK-info %r' % lines[2])
|
||||
|
||||
if cipher == 'AES-128-CBC':
|
||||
CipherClass = AES
|
||||
keySize = 16
|
||||
if len(ivdata) != 32:
|
||||
raise BadKeyError('AES encrypted key with a bad IV')
|
||||
elif cipher == 'DES-EDE3-CBC':
|
||||
CipherClass = DES3
|
||||
keySize = 24
|
||||
if len(ivdata) != 16:
|
||||
raise BadKeyError('DES encrypted key with a bad IV')
|
||||
else:
|
||||
raise BadKeyError('unknown encryption type %r' % cipher)
|
||||
|
||||
# extract keyData for decoding
|
||||
iv = ''.join([chr(int(ivdata[i:i + 2], 16))
|
||||
for i in range(0, len(ivdata), 2)])
|
||||
ba = md5(passphrase + iv[:8]).digest()
|
||||
bb = md5(ba + passphrase + iv[:8]).digest()
|
||||
decKey = (ba + bb)[:keySize]
|
||||
b64Data = base64.decodestring(''.join(lines[3:-1]))
|
||||
keyData = CipherClass.new(decKey,
|
||||
CipherClass.MODE_CBC,
|
||||
iv).decrypt(b64Data)
|
||||
removeLen = ord(keyData[-1])
|
||||
keyData = keyData[:-removeLen]
|
||||
else:
|
||||
b64Data = ''.join(lines[1:-1])
|
||||
keyData = base64.decodestring(b64Data)
|
||||
|
||||
try:
|
||||
decodedKey = berDecoder.decode(keyData)[0]
|
||||
except PyAsn1Error, e:
|
||||
raise BadKeyError('Failed to decode key (Bad Passphrase?): %s' % e)
|
||||
|
||||
if kind == 'RSA':
|
||||
if len(decodedKey) == 2: # alternate RSA key
|
||||
decodedKey = decodedKey[0]
|
||||
if len(decodedKey) < 6:
|
||||
raise BadKeyError('RSA key failed to decode properly')
|
||||
|
||||
n, e, d, p, q = [long(value) for value in decodedKey[1:6]]
|
||||
if p > q: # make p smaller than q
|
||||
p, q = q, p
|
||||
return Class(RSA.construct((n, e, d, p, q)))
|
||||
elif kind == 'DSA':
|
||||
p, q, g, y, x = [long(value) for value in decodedKey[1: 6]]
|
||||
if len(decodedKey) < 6:
|
||||
raise BadKeyError('DSA key failed to decode properly')
|
||||
return Class(DSA.construct((y, g, p, q, x)))
|
||||
_fromString_PRIVATE_OPENSSH = classmethod(_fromString_PRIVATE_OPENSSH)
|
||||
|
||||
|
||||
def _fromString_PUBLIC_LSH(Class, data):
|
||||
"""
|
||||
Return a public key corresponding to this LSH public key string.
|
||||
The LSH public key string format is::
|
||||
<s-expression: ('public-key', (<key type>, (<name, <value>)+))>
|
||||
|
||||
The names for a RSA (key type 'rsa-pkcs1-sha1') key are: n, e.
|
||||
The names for a DSA (key type 'dsa') key are: y, g, p, q.
|
||||
|
||||
@type data: C{str}
|
||||
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the key type is unknown
|
||||
"""
|
||||
sexp = sexpy.parse(base64.decodestring(data[1:-1]))
|
||||
assert sexp[0] == 'public-key'
|
||||
kd = {}
|
||||
for name, data in sexp[1][1:]:
|
||||
kd[name] = common.getMP(common.NS(data))[0]
|
||||
if sexp[1][0] == 'dsa':
|
||||
return Class(DSA.construct((kd['y'], kd['g'], kd['p'], kd['q'])))
|
||||
elif sexp[1][0] == 'rsa-pkcs1-sha1':
|
||||
return Class(RSA.construct((kd['n'], kd['e'])))
|
||||
else:
|
||||
raise BadKeyError('unknown lsh key type %s' % sexp[1][0])
|
||||
_fromString_PUBLIC_LSH = classmethod(_fromString_PUBLIC_LSH)
|
||||
|
||||
|
||||
def _fromString_PRIVATE_LSH(Class, data):
|
||||
"""
|
||||
Return a private key corresponding to this LSH private key string.
|
||||
The LSH private key string format is::
|
||||
<s-expression: ('private-key', (<key type>, (<name>, <value>)+))>
|
||||
|
||||
The names for a RSA (key type 'rsa-pkcs1-sha1') key are: n, e, d, p, q.
|
||||
The names for a DSA (key type 'dsa') key are: y, g, p, q, x.
|
||||
|
||||
@type data: C{str}
|
||||
@return: a {Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the key type is unknown
|
||||
"""
|
||||
sexp = sexpy.parse(data)
|
||||
assert sexp[0] == 'private-key'
|
||||
kd = {}
|
||||
for name, data in sexp[1][1:]:
|
||||
kd[name] = common.getMP(common.NS(data))[0]
|
||||
if sexp[1][0] == 'dsa':
|
||||
assert len(kd) == 5, len(kd)
|
||||
return Class(DSA.construct((kd['y'], kd['g'], kd['p'],
|
||||
kd['q'], kd['x'])))
|
||||
elif sexp[1][0] == 'rsa-pkcs1':
|
||||
assert len(kd) == 8, len(kd)
|
||||
if kd['p'] > kd['q']: # make p smaller than q
|
||||
kd['p'], kd['q'] = kd['q'], kd['p']
|
||||
return Class(RSA.construct((kd['n'], kd['e'], kd['d'],
|
||||
kd['p'], kd['q'])))
|
||||
else:
|
||||
raise BadKeyError('unknown lsh key type %s' % sexp[1][0])
|
||||
_fromString_PRIVATE_LSH = classmethod(_fromString_PRIVATE_LSH)
|
||||
|
||||
|
||||
def _fromString_AGENTV3(Class, data):
|
||||
"""
|
||||
Return a private key object corresponsing to the Secure Shell Key
|
||||
Agent v3 format.
|
||||
|
||||
The SSH Key Agent v3 format for a RSA key is::
|
||||
string 'ssh-rsa'
|
||||
integer e
|
||||
integer d
|
||||
integer n
|
||||
integer u
|
||||
integer p
|
||||
integer q
|
||||
|
||||
The SSH Key Agent v3 format for a DSA key is::
|
||||
string 'ssh-dss'
|
||||
integer p
|
||||
integer q
|
||||
integer g
|
||||
integer y
|
||||
integer x
|
||||
|
||||
@type data: C{str}
|
||||
@return: a C{Crypto.PublicKey.pubkey.pubkey} object
|
||||
@raises BadKeyError: if the key type (the first string) is unknown
|
||||
"""
|
||||
keyType, data = common.getNS(data)
|
||||
if keyType == 'ssh-dss':
|
||||
p, data = common.getMP(data)
|
||||
q, data = common.getMP(data)
|
||||
g, data = common.getMP(data)
|
||||
y, data = common.getMP(data)
|
||||
x, data = common.getMP(data)
|
||||
return Class(DSA.construct((y, g, p, q, x)))
|
||||
elif keyType == 'ssh-rsa':
|
||||
e, data = common.getMP(data)
|
||||
d, data = common.getMP(data)
|
||||
n, data = common.getMP(data)
|
||||
u, data = common.getMP(data)
|
||||
p, data = common.getMP(data)
|
||||
q, data = common.getMP(data)
|
||||
return Class(RSA.construct((n, e, d, p, q, u)))
|
||||
else:
|
||||
raise BadKeyError("unknown key type %s" % keyType)
|
||||
_fromString_AGENTV3 = classmethod(_fromString_AGENTV3)
|
||||
|
||||
|
||||
def _guessStringType(Class, data):
|
||||
"""
|
||||
Guess the type of key in data. The types map to _fromString_*
|
||||
methods.
|
||||
"""
|
||||
if data.startswith('ssh-'):
|
||||
return 'public_openssh'
|
||||
elif data.startswith('-----BEGIN'):
|
||||
return 'private_openssh'
|
||||
elif data.startswith('{'):
|
||||
return 'public_lsh'
|
||||
elif data.startswith('('):
|
||||
return 'private_lsh'
|
||||
elif data.startswith('\x00\x00\x00\x07ssh-'):
|
||||
ignored, rest = common.getNS(data)
|
||||
count = 0
|
||||
while rest:
|
||||
count += 1
|
||||
ignored, rest = common.getMP(rest)
|
||||
if count > 4:
|
||||
return 'agentv3'
|
||||
else:
|
||||
return 'blob'
|
||||
_guessStringType = classmethod(_guessStringType)
|
||||
|
||||
|
||||
def __init__(self, keyObject):
|
||||
"""
|
||||
Initialize a PublicKey with a C{Crypto.PublicKey.pubkey.pubkey}
|
||||
object.
|
||||
|
||||
@type keyObject: C{Crypto.PublicKey.pubkey.pubkey}
|
||||
"""
|
||||
self.keyObject = keyObject
|
||||
|
||||
|
||||
def __eq__(self, other):
|
||||
"""
|
||||
Return True if other represents an object with the same key.
|
||||
"""
|
||||
if type(self) == type(other):
|
||||
return self.type() == other.type() and self.data() == other.data()
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
|
||||
def __ne__(self, other):
|
||||
"""
|
||||
Return True if other represents anything other than this key.
|
||||
"""
|
||||
result = self.__eq__(other)
|
||||
if result == NotImplemented:
|
||||
return result
|
||||
return not result
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
"""
|
||||
Return a pretty representation of this object.
|
||||
"""
|
||||
lines = [
|
||||
'<%s %s (%s bits)' % (
|
||||
self.type(),
|
||||
self.isPublic() and 'Public Key' or 'Private Key',
|
||||
self.keyObject.size())]
|
||||
for k, v in sorted(self.data().items()):
|
||||
lines.append('attr %s:' % k)
|
||||
by = common.MP(v)[4:]
|
||||
while by:
|
||||
m = by[:15]
|
||||
by = by[15:]
|
||||
o = ''
|
||||
for c in m:
|
||||
o = o + '%02x:' % ord(c)
|
||||
if len(m) < 15:
|
||||
o = o[:-1]
|
||||
lines.append('\t' + o)
|
||||
lines[-1] = lines[-1] + '>'
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def isPublic(self):
|
||||
"""
|
||||
Returns True if this Key is a public key.
|
||||
"""
|
||||
return not self.keyObject.has_private()
|
||||
|
||||
|
||||
def public(self):
|
||||
"""
|
||||
Returns a version of this key containing only the public key data.
|
||||
If this is a public key, this may or may not be the same object
|
||||
as self.
|
||||
"""
|
||||
return Key(self.keyObject.publickey())
|
||||
|
||||
|
||||
def fingerprint(self):
|
||||
"""
|
||||
Get the user presentation of the fingerprint of this L{Key}. As
|
||||
described by U{RFC 4716 section
|
||||
4<http://tools.ietf.org/html/rfc4716#section-4>}::
|
||||
|
||||
The fingerprint of a public key consists of the output of the MD5
|
||||
message-digest algorithm [RFC1321]. The input to the algorithm is
|
||||
the public key data as specified by [RFC4253]. (...) The output
|
||||
of the (MD5) algorithm is presented to the user as a sequence of 16
|
||||
octets printed as hexadecimal with lowercase letters and separated
|
||||
by colons.
|
||||
|
||||
@since: 8.2
|
||||
|
||||
@return: the user presentation of this L{Key}'s fingerprint, as a
|
||||
string.
|
||||
|
||||
@rtype: L{str}
|
||||
"""
|
||||
return ':'.join([x.encode('hex') for x in md5(self.blob()).digest()])
|
||||
|
||||
|
||||
def type(self):
|
||||
"""
|
||||
Return the type of the object we wrap. Currently this can only be
|
||||
'RSA' or 'DSA'.
|
||||
"""
|
||||
# the class is Crypto.PublicKey.<type>.<stuff we don't care about>
|
||||
mod = self.keyObject.__class__.__module__
|
||||
if mod.startswith('Crypto.PublicKey'):
|
||||
type = mod.split('.')[2]
|
||||
else:
|
||||
raise RuntimeError('unknown type of object: %r' % self.keyObject)
|
||||
if type in ('RSA', 'DSA'):
|
||||
return type
|
||||
else:
|
||||
raise RuntimeError('unknown type of key: %s' % type)
|
||||
|
||||
|
||||
def sshType(self):
|
||||
"""
|
||||
Return the type of the object we wrap as defined in the ssh protocol.
|
||||
Currently this can only be 'ssh-rsa' or 'ssh-dss'.
|
||||
"""
|
||||
return {'RSA': 'ssh-rsa', 'DSA': 'ssh-dss'}[self.type()]
|
||||
|
||||
|
||||
def data(self):
|
||||
"""
|
||||
Return the values of the public key as a dictionary.
|
||||
|
||||
@rtype: C{dict}
|
||||
"""
|
||||
keyData = {}
|
||||
for name in self.keyObject.keydata:
|
||||
value = getattr(self.keyObject, name, None)
|
||||
if value is not None:
|
||||
keyData[name] = value
|
||||
return keyData
|
||||
|
||||
|
||||
def blob(self):
|
||||
"""
|
||||
Return the public key blob for this key. The blob is the
|
||||
over-the-wire format for public keys:
|
||||
|
||||
RSA keys::
|
||||
string 'ssh-rsa'
|
||||
integer e
|
||||
integer n
|
||||
|
||||
DSA keys::
|
||||
string 'ssh-dss'
|
||||
integer p
|
||||
integer q
|
||||
integer g
|
||||
integer y
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
type = self.type()
|
||||
data = self.data()
|
||||
if type == 'RSA':
|
||||
return (common.NS('ssh-rsa') + common.MP(data['e']) +
|
||||
common.MP(data['n']))
|
||||
elif type == 'DSA':
|
||||
return (common.NS('ssh-dss') + common.MP(data['p']) +
|
||||
common.MP(data['q']) + common.MP(data['g']) +
|
||||
common.MP(data['y']))
|
||||
|
||||
|
||||
def privateBlob(self):
|
||||
"""
|
||||
Return the private key blob for this key. The blob is the
|
||||
over-the-wire format for private keys:
|
||||
|
||||
RSA keys::
|
||||
string 'ssh-rsa'
|
||||
integer n
|
||||
integer e
|
||||
integer d
|
||||
integer u
|
||||
integer p
|
||||
integer q
|
||||
|
||||
DSA keys::
|
||||
string 'ssh-dss'
|
||||
integer p
|
||||
integer q
|
||||
integer g
|
||||
integer y
|
||||
integer x
|
||||
"""
|
||||
type = self.type()
|
||||
data = self.data()
|
||||
if type == 'RSA':
|
||||
return (common.NS('ssh-rsa') + common.MP(data['n']) +
|
||||
common.MP(data['e']) + common.MP(data['d']) +
|
||||
common.MP(data['u']) + common.MP(data['p']) +
|
||||
common.MP(data['q']))
|
||||
elif type == 'DSA':
|
||||
return (common.NS('ssh-dss') + common.MP(data['p']) +
|
||||
common.MP(data['q']) + common.MP(data['g']) +
|
||||
common.MP(data['y']) + common.MP(data['x']))
|
||||
|
||||
|
||||
def toString(self, type, extra=None):
|
||||
"""
|
||||
Create a string representation of this key. If the key is a private
|
||||
key and you want the represenation of its public key, use
|
||||
C{key.public().toString()}. type maps to a _toString_* method.
|
||||
|
||||
@param type: The type of string to emit. Currently supported values
|
||||
are C{'OPENSSH'}, C{'LSH'}, and C{'AGENTV3'}.
|
||||
@type type: L{str}
|
||||
|
||||
@param extra: Any extra data supported by the selected format which
|
||||
is not part of the key itself. For public OpenSSH keys, this is
|
||||
a comment. For private OpenSSH keys, this is a passphrase to
|
||||
encrypt with.
|
||||
@type extra: L{str} or L{NoneType}
|
||||
|
||||
@rtype: L{str}
|
||||
"""
|
||||
method = getattr(self, '_toString_%s' % type.upper(), None)
|
||||
if method is None:
|
||||
raise BadKeyError('unknown type: %s' % type)
|
||||
if method.func_code.co_argcount == 2:
|
||||
return method(extra)
|
||||
else:
|
||||
return method()
|
||||
|
||||
|
||||
def _toString_OPENSSH(self, extra):
|
||||
"""
|
||||
Return a public or private OpenSSH string. See
|
||||
_fromString_PUBLIC_OPENSSH and _fromString_PRIVATE_OPENSSH for the
|
||||
string formats. If extra is present, it represents a comment for a
|
||||
public key, or a passphrase for a private key.
|
||||
|
||||
@param extra: Comment for a public key or passphrase for a
|
||||
private key
|
||||
@type extra: C{str}
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
data = self.data()
|
||||
if self.isPublic():
|
||||
b64Data = base64.encodestring(self.blob()).replace('\n', '')
|
||||
if not extra:
|
||||
extra = ''
|
||||
return ('%s %s %s' % (self.sshType(), b64Data, extra)).strip()
|
||||
else:
|
||||
lines = ['-----BEGIN %s PRIVATE KEY-----' % self.type()]
|
||||
if self.type() == 'RSA':
|
||||
p, q = data['p'], data['q']
|
||||
objData = (0, data['n'], data['e'], data['d'], q, p,
|
||||
data['d'] % (q - 1), data['d'] % (p - 1),
|
||||
data['u'])
|
||||
else:
|
||||
objData = (0, data['p'], data['q'], data['g'], data['y'],
|
||||
data['x'])
|
||||
asn1Sequence = univ.Sequence()
|
||||
for index, value in itertools.izip(itertools.count(), objData):
|
||||
asn1Sequence.setComponentByPosition(index, univ.Integer(value))
|
||||
asn1Data = berEncoder.encode(asn1Sequence)
|
||||
if extra:
|
||||
iv = randbytes.secureRandom(8)
|
||||
hexiv = ''.join(['%02X' % ord(x) for x in iv])
|
||||
lines.append('Proc-Type: 4,ENCRYPTED')
|
||||
lines.append('DEK-Info: DES-EDE3-CBC,%s\n' % hexiv)
|
||||
ba = md5(extra + iv).digest()
|
||||
bb = md5(ba + extra + iv).digest()
|
||||
encKey = (ba + bb)[:24]
|
||||
padLen = 8 - (len(asn1Data) % 8)
|
||||
asn1Data += (chr(padLen) * padLen)
|
||||
asn1Data = DES3.new(encKey, DES3.MODE_CBC,
|
||||
iv).encrypt(asn1Data)
|
||||
b64Data = base64.encodestring(asn1Data).replace('\n', '')
|
||||
lines += [b64Data[i:i + 64] for i in range(0, len(b64Data), 64)]
|
||||
lines.append('-----END %s PRIVATE KEY-----' % self.type())
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def _toString_LSH(self):
|
||||
"""
|
||||
Return a public or private LSH key. See _fromString_PUBLIC_LSH and
|
||||
_fromString_PRIVATE_LSH for the key formats.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
data = self.data()
|
||||
if self.isPublic():
|
||||
if self.type() == 'RSA':
|
||||
keyData = sexpy.pack([['public-key',
|
||||
['rsa-pkcs1-sha1',
|
||||
['n', common.MP(data['n'])[4:]],
|
||||
['e', common.MP(data['e'])[4:]]]]])
|
||||
elif self.type() == 'DSA':
|
||||
keyData = sexpy.pack([['public-key',
|
||||
['dsa',
|
||||
['p', common.MP(data['p'])[4:]],
|
||||
['q', common.MP(data['q'])[4:]],
|
||||
['g', common.MP(data['g'])[4:]],
|
||||
['y', common.MP(data['y'])[4:]]]]])
|
||||
return '{' + base64.encodestring(keyData).replace('\n', '') + '}'
|
||||
else:
|
||||
if self.type() == 'RSA':
|
||||
p, q = data['p'], data['q']
|
||||
return sexpy.pack([['private-key',
|
||||
['rsa-pkcs1',
|
||||
['n', common.MP(data['n'])[4:]],
|
||||
['e', common.MP(data['e'])[4:]],
|
||||
['d', common.MP(data['d'])[4:]],
|
||||
['p', common.MP(q)[4:]],
|
||||
['q', common.MP(p)[4:]],
|
||||
['a', common.MP(data['d'] % (q - 1))[4:]],
|
||||
['b', common.MP(data['d'] % (p - 1))[4:]],
|
||||
['c', common.MP(data['u'])[4:]]]]])
|
||||
elif self.type() == 'DSA':
|
||||
return sexpy.pack([['private-key',
|
||||
['dsa',
|
||||
['p', common.MP(data['p'])[4:]],
|
||||
['q', common.MP(data['q'])[4:]],
|
||||
['g', common.MP(data['g'])[4:]],
|
||||
['y', common.MP(data['y'])[4:]],
|
||||
['x', common.MP(data['x'])[4:]]]]])
|
||||
|
||||
|
||||
def _toString_AGENTV3(self):
|
||||
"""
|
||||
Return a private Secure Shell Agent v3 key. See
|
||||
_fromString_AGENTV3 for the key format.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
data = self.data()
|
||||
if not self.isPublic():
|
||||
if self.type() == 'RSA':
|
||||
values = (data['e'], data['d'], data['n'], data['u'],
|
||||
data['p'], data['q'])
|
||||
elif self.type() == 'DSA':
|
||||
values = (data['p'], data['q'], data['g'], data['y'],
|
||||
data['x'])
|
||||
return common.NS(self.sshType()) + ''.join(map(common.MP, values))
|
||||
|
||||
|
||||
def sign(self, data):
|
||||
"""
|
||||
Returns a signature with this Key.
|
||||
|
||||
@type data: C{str}
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if self.type() == 'RSA':
|
||||
digest = pkcs1Digest(data, self.keyObject.size() / 8)
|
||||
signature = self.keyObject.sign(digest, '')[0]
|
||||
ret = common.NS(Util.number.long_to_bytes(signature))
|
||||
elif self.type() == 'DSA':
|
||||
digest = sha1(data).digest()
|
||||
randomBytes = randbytes.secureRandom(19)
|
||||
sig = self.keyObject.sign(digest, randomBytes)
|
||||
# SSH insists that the DSS signature blob be two 160-bit integers
|
||||
# concatenated together. The sig[0], [1] numbers from obj.sign
|
||||
# are just numbers, and could be any length from 0 to 160 bits.
|
||||
# Make sure they are padded out to 160 bits (20 bytes each)
|
||||
ret = common.NS(Util.number.long_to_bytes(sig[0], 20) +
|
||||
Util.number.long_to_bytes(sig[1], 20))
|
||||
return common.NS(self.sshType()) + ret
|
||||
|
||||
|
||||
def verify(self, signature, data):
|
||||
"""
|
||||
Returns true if the signature for data is valid for this Key.
|
||||
|
||||
@type signature: C{str}
|
||||
@type data: C{str}
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
if len(signature) == 40:
|
||||
# DSA key with no padding
|
||||
signatureType, signature = 'ssh-dss', common.NS(signature)
|
||||
else:
|
||||
signatureType, signature = common.getNS(signature)
|
||||
if signatureType != self.sshType():
|
||||
return False
|
||||
if self.type() == 'RSA':
|
||||
numbers = common.getMP(signature)
|
||||
digest = pkcs1Digest(data, self.keyObject.size() / 8)
|
||||
elif self.type() == 'DSA':
|
||||
signature = common.getNS(signature)[0]
|
||||
numbers = [Util.number.bytes_to_long(n) for n in signature[:20],
|
||||
signature[20:]]
|
||||
digest = sha1(data).digest()
|
||||
return self.keyObject.verify(digest, numbers)
|
||||
|
||||
|
||||
|
||||
def objectType(obj):
|
||||
"""
|
||||
Return the SSH key type corresponding to a
|
||||
C{Crypto.PublicKey.pubkey.pubkey} object.
|
||||
|
||||
@type obj: C{Crypto.PublicKey.pubkey.pubkey}
|
||||
@rtype: C{str}
|
||||
"""
|
||||
keyDataMapping = {
|
||||
('n', 'e', 'd', 'p', 'q'): 'ssh-rsa',
|
||||
('n', 'e', 'd', 'p', 'q', 'u'): 'ssh-rsa',
|
||||
('y', 'g', 'p', 'q', 'x'): 'ssh-dss'
|
||||
}
|
||||
try:
|
||||
return keyDataMapping[tuple(obj.keydata)]
|
||||
except (KeyError, AttributeError):
|
||||
raise BadKeyError("invalid key object", obj)
|
||||
|
||||
|
||||
|
||||
def pkcs1Pad(data, messageLength):
|
||||
"""
|
||||
Pad out data to messageLength according to the PKCS#1 standard.
|
||||
@type data: C{str}
|
||||
@type messageLength: C{int}
|
||||
"""
|
||||
lenPad = messageLength - 2 - len(data)
|
||||
return '\x01' + ('\xff' * lenPad) + '\x00' + data
|
||||
|
||||
|
||||
|
||||
def pkcs1Digest(data, messageLength):
|
||||
"""
|
||||
Create a message digest using the SHA1 hash algorithm according to the
|
||||
PKCS#1 standard.
|
||||
@type data: C{str}
|
||||
@type messageLength: C{str}
|
||||
"""
|
||||
digest = sha1(data).digest()
|
||||
return pkcs1Pad(ID_SHA1 + digest, messageLength)
|
||||
|
||||
|
||||
|
||||
def lenSig(obj):
|
||||
"""
|
||||
Return the length of the signature in bytes for a key object.
|
||||
|
||||
@type obj: C{Crypto.PublicKey.pubkey.pubkey}
|
||||
@rtype: C{long}
|
||||
"""
|
||||
return obj.size() / 8
|
||||
|
||||
|
||||
ID_SHA1 = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
|
@ -0,0 +1,48 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
The parent class for all the SSH services. Currently implemented services
|
||||
are ssh-userauth and ssh-connection.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
|
||||
from twisted.python import log
|
||||
|
||||
class SSHService(log.Logger):
|
||||
name = None # this is the ssh name for the service
|
||||
protocolMessages = {} # these map #'s -> protocol names
|
||||
transport = None # gets set later
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
called when the service is active on the transport.
|
||||
"""
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
called when the service is stopped, either by the connection ending
|
||||
or by another service being started
|
||||
"""
|
||||
|
||||
def logPrefix(self):
|
||||
return "SSHService %s on %s" % (self.name,
|
||||
self.transport.transport.logPrefix())
|
||||
|
||||
def packetReceived(self, messageNum, packet):
|
||||
"""
|
||||
called when we receive a packet on the transport
|
||||
"""
|
||||
#print self.protocolMessages
|
||||
if messageNum in self.protocolMessages:
|
||||
messageType = self.protocolMessages[messageNum]
|
||||
f = getattr(self,'ssh_%s' % messageType[4:],
|
||||
None)
|
||||
if f is not None:
|
||||
return f(packet)
|
||||
log.msg("couldn't handle %r" % messageNum)
|
||||
log.msg(repr(packet))
|
||||
self.transport.sendUnimplemented()
|
||||
|
349
1_7.http_proxy_server/python/Twisted-15.2.1-source/twisted/conch/ssh/session.py
Executable file
349
1_7.http_proxy_server/python/Twisted-15.2.1-source/twisted/conch/ssh/session.py
Executable file
@ -0,0 +1,349 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_session -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
This module contains the implementation of SSHSession, which (by default)
|
||||
allows access to a shell and a python interpreter over SSH.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces, protocol
|
||||
from twisted.python import log
|
||||
from twisted.conch.interfaces import ISession
|
||||
from twisted.conch.ssh import common, channel
|
||||
|
||||
|
||||
class SSHSession(channel.SSHChannel):
|
||||
|
||||
name = 'session'
|
||||
def __init__(self, *args, **kw):
|
||||
channel.SSHChannel.__init__(self, *args, **kw)
|
||||
self.buf = ''
|
||||
self.client = None
|
||||
self.session = None
|
||||
|
||||
def request_subsystem(self, data):
|
||||
subsystem, ignored= common.getNS(data)
|
||||
log.msg('asking for subsystem "%s"' % subsystem)
|
||||
client = self.avatar.lookupSubsystem(subsystem, data)
|
||||
if client:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
proto = wrapProcessProtocol(pp)
|
||||
client.makeConnection(proto)
|
||||
pp.makeConnection(wrapProtocol(client))
|
||||
self.client = pp
|
||||
return 1
|
||||
else:
|
||||
log.msg('failed to get subsystem')
|
||||
return 0
|
||||
|
||||
def request_shell(self, data):
|
||||
log.msg('getting shell')
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
try:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
self.session.openShell(pp)
|
||||
except:
|
||||
log.deferr()
|
||||
return 0
|
||||
else:
|
||||
self.client = pp
|
||||
return 1
|
||||
|
||||
def request_exec(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
f,data = common.getNS(data)
|
||||
log.msg('executing command "%s"' % f)
|
||||
try:
|
||||
pp = SSHSessionProcessProtocol(self)
|
||||
self.session.execCommand(pp, f)
|
||||
except:
|
||||
log.deferr()
|
||||
return 0
|
||||
else:
|
||||
self.client = pp
|
||||
return 1
|
||||
|
||||
def request_pty_req(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
term, windowSize, modes = parseRequest_pty_req(data)
|
||||
log.msg('pty request: %s %s' % (term, windowSize))
|
||||
try:
|
||||
self.session.getPty(term, windowSize, modes)
|
||||
except:
|
||||
log.err()
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def request_window_change(self, data):
|
||||
if not self.session:
|
||||
self.session = ISession(self.avatar)
|
||||
winSize = parseRequest_window_change(data)
|
||||
try:
|
||||
self.session.windowChanged(winSize)
|
||||
except:
|
||||
log.msg('error changing window size')
|
||||
log.err()
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
def dataReceived(self, data):
|
||||
if not self.client:
|
||||
#self.conn.sendClose(self)
|
||||
self.buf += data
|
||||
return
|
||||
self.client.transport.write(data)
|
||||
|
||||
def extReceived(self, dataType, data):
|
||||
if dataType == connection.EXTENDED_DATA_STDERR:
|
||||
if self.client and hasattr(self.client.transport, 'writeErr'):
|
||||
self.client.transport.writeErr(data)
|
||||
else:
|
||||
log.msg('weird extended data: %s'%dataType)
|
||||
|
||||
def eofReceived(self):
|
||||
if self.session:
|
||||
self.session.eofReceived()
|
||||
elif self.client:
|
||||
self.conn.sendClose(self)
|
||||
|
||||
def closed(self):
|
||||
if self.session:
|
||||
self.session.closed()
|
||||
elif self.client:
|
||||
self.client.transport.loseConnection()
|
||||
|
||||
#def closeReceived(self):
|
||||
# self.loseConnection() # don't know what to do with this
|
||||
|
||||
def loseConnection(self):
|
||||
if self.client:
|
||||
self.client.transport.loseConnection()
|
||||
channel.SSHChannel.loseConnection(self)
|
||||
|
||||
class _ProtocolWrapper(protocol.ProcessProtocol):
|
||||
"""
|
||||
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
|
||||
"""
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
|
||||
def connectionMade(self): self.proto.connectionMade()
|
||||
|
||||
def outReceived(self, data): self.proto.dataReceived(data)
|
||||
|
||||
def processEnded(self, reason): self.proto.connectionLost(reason)
|
||||
|
||||
class _DummyTransport:
|
||||
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.proto.transport.write(data)
|
||||
|
||||
def write(self, data):
|
||||
self.proto.dataReceived(data)
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.write(''.join(seq))
|
||||
|
||||
def loseConnection(self):
|
||||
self.proto.connectionLost(protocol.connectionDone)
|
||||
|
||||
def wrapProcessProtocol(inst):
|
||||
if isinstance(inst, protocol.Protocol):
|
||||
return _ProtocolWrapper(inst)
|
||||
else:
|
||||
return inst
|
||||
|
||||
def wrapProtocol(proto):
|
||||
return _DummyTransport(proto)
|
||||
|
||||
|
||||
|
||||
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
|
||||
# to accept. See RFC 4254
|
||||
SUPPORTED_SIGNALS = ["ABRT", "ALRM", "FPE", "HUP", "ILL", "INT", "KILL",
|
||||
"PIPE", "QUIT", "SEGV", "TERM", "USR1", "USR2"]
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport)
|
||||
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
|
||||
"""I am both an L{IProcessProtocol} and an L{ITransport}.
|
||||
|
||||
I am a transport to the remote endpoint and a process protocol to the
|
||||
local subsystem.
|
||||
"""
|
||||
|
||||
# once initialized, a dictionary mapping signal values to strings
|
||||
# that follow RFC 4254.
|
||||
_signalValuesToNames = None
|
||||
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
self.lostOutOrErrFlag = False
|
||||
|
||||
def connectionMade(self):
|
||||
if self.session.buf:
|
||||
self.transport.write(self.session.buf)
|
||||
self.session.buf = None
|
||||
|
||||
def outReceived(self, data):
|
||||
self.session.write(data)
|
||||
|
||||
def errReceived(self, err):
|
||||
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
|
||||
|
||||
def outConnectionLost(self):
|
||||
"""
|
||||
EOF should only be sent when both STDOUT and STDERR have been closed.
|
||||
"""
|
||||
if self.lostOutOrErrFlag:
|
||||
self.session.conn.sendEOF(self.session)
|
||||
else:
|
||||
self.lostOutOrErrFlag = True
|
||||
|
||||
def errConnectionLost(self):
|
||||
"""
|
||||
See outConnectionLost().
|
||||
"""
|
||||
self.outConnectionLost()
|
||||
|
||||
def connectionLost(self, reason = None):
|
||||
self.session.loseConnection()
|
||||
|
||||
|
||||
def _getSignalName(self, signum):
|
||||
"""
|
||||
Get a signal name given a signal number.
|
||||
"""
|
||||
if self._signalValuesToNames is None:
|
||||
self._signalValuesToNames = {}
|
||||
# make sure that the POSIX ones are the defaults
|
||||
for signame in SUPPORTED_SIGNALS:
|
||||
signame = 'SIG' + signame
|
||||
sigvalue = getattr(signal, signame, None)
|
||||
if sigvalue is not None:
|
||||
self._signalValuesToNames[sigvalue] = signame
|
||||
for k, v in signal.__dict__.items():
|
||||
# Check for platform specific signals, ignoring Python specific
|
||||
# SIG_DFL and SIG_IGN
|
||||
if k.startswith('SIG') and not k.startswith('SIG_'):
|
||||
if v not in self._signalValuesToNames:
|
||||
self._signalValuesToNames[v] = k + '@' + sys.platform
|
||||
return self._signalValuesToNames[signum]
|
||||
|
||||
|
||||
def processEnded(self, reason=None):
|
||||
"""
|
||||
When we are told the process ended, try to notify the other side about
|
||||
how the process ended using the exit-signal or exit-status requests.
|
||||
Also, close the channel.
|
||||
"""
|
||||
if reason is not None:
|
||||
err = reason.value
|
||||
if err.signal is not None:
|
||||
signame = self._getSignalName(err.signal)
|
||||
if (getattr(os, 'WCOREDUMP', None) is not None and
|
||||
os.WCOREDUMP(err.status)):
|
||||
log.msg('exitSignal: %s (core dumped)' % (signame,))
|
||||
coreDumped = 1
|
||||
else:
|
||||
log.msg('exitSignal: %s' % (signame,))
|
||||
coreDumped = 0
|
||||
self.session.conn.sendRequest(self.session, 'exit-signal',
|
||||
common.NS(signame[3:]) + chr(coreDumped) +
|
||||
common.NS('') + common.NS(''))
|
||||
elif err.exitCode is not None:
|
||||
log.msg('exitCode: %r' % (err.exitCode,))
|
||||
self.session.conn.sendRequest(self.session, 'exit-status',
|
||||
struct.pack('>L', err.exitCode))
|
||||
self.session.loseConnection()
|
||||
|
||||
|
||||
def getHost(self):
|
||||
"""
|
||||
Return the host from my session's transport.
|
||||
"""
|
||||
return self.session.conn.transport.getHost()
|
||||
|
||||
|
||||
def getPeer(self):
|
||||
"""
|
||||
Return the peer from my session's transport.
|
||||
"""
|
||||
return self.session.conn.transport.getPeer()
|
||||
|
||||
|
||||
def write(self, data):
|
||||
self.session.write(data)
|
||||
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.session.write(''.join(seq))
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
self.session.loseConnection()
|
||||
|
||||
|
||||
|
||||
class SSHSessionClient(protocol.Protocol):
|
||||
|
||||
def dataReceived(self, data):
|
||||
if self.transport:
|
||||
self.transport.write(data)
|
||||
|
||||
# methods factored out to make live easier on server writers
|
||||
def parseRequest_pty_req(data):
|
||||
"""Parse the data from a pty-req request into usable data.
|
||||
|
||||
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
|
||||
"""
|
||||
term, rest = common.getNS(data)
|
||||
cols, rows, xpixel, ypixel = struct.unpack('>4L', rest[: 16])
|
||||
modes, ignored= common.getNS(rest[16:])
|
||||
winSize = (rows, cols, xpixel, ypixel)
|
||||
modes = [(ord(modes[i]), struct.unpack('>L', modes[i+1: i+5])[0]) for i in range(0, len(modes)-1, 5)]
|
||||
return term, winSize, modes
|
||||
|
||||
def packRequest_pty_req(term, (rows, cols, xpixel, ypixel), modes):
|
||||
"""Pack a pty-req request so that it is suitable for sending.
|
||||
|
||||
NOTE: modes must be packed before being sent here.
|
||||
"""
|
||||
termPacked = common.NS(term)
|
||||
winSizePacked = struct.pack('>4L', cols, rows, xpixel, ypixel)
|
||||
modesPacked = common.NS(modes) # depend on the client packing modes
|
||||
return termPacked + winSizePacked + modesPacked
|
||||
|
||||
def parseRequest_window_change(data):
|
||||
"""Parse the data from a window-change request into usuable data.
|
||||
|
||||
@returns: a tuple of (rows, cols, xpixel, ypixel)
|
||||
"""
|
||||
cols, rows, xpixel, ypixel = struct.unpack('>4L', data)
|
||||
return rows, cols, xpixel, ypixel
|
||||
|
||||
def packRequest_window_change((rows, cols, xpixel, ypixel)):
|
||||
"""Pack a window-change request so that it is suitable for sending.
|
||||
"""
|
||||
return struct.pack('>4L', cols, rows, xpixel, ypixel)
|
||||
|
||||
import connection
|
@ -0,0 +1,42 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
#
|
||||
|
||||
def parse(s):
|
||||
s = s.strip()
|
||||
expr = []
|
||||
while s:
|
||||
if s[0] == '(':
|
||||
newSexp = []
|
||||
if expr:
|
||||
expr[-1].append(newSexp)
|
||||
expr.append(newSexp)
|
||||
s = s[1:]
|
||||
continue
|
||||
if s[0] == ')':
|
||||
aList = expr.pop()
|
||||
s=s[1:]
|
||||
if not expr:
|
||||
assert not s
|
||||
return aList
|
||||
continue
|
||||
i = 0
|
||||
while s[i].isdigit(): i+=1
|
||||
assert i
|
||||
length = int(s[:i])
|
||||
data = s[i+1:i+1+length]
|
||||
expr[-1].append(data)
|
||||
s=s[i+1+length:]
|
||||
assert 0, "this should not happen"
|
||||
|
||||
def pack(sexp):
|
||||
s = ""
|
||||
for o in sexp:
|
||||
if type(o) in (type(()), type([])):
|
||||
s+='('
|
||||
s+=pack(o)
|
||||
s+=')'
|
||||
else:
|
||||
s+='%i:%s' % (len(o), o)
|
||||
return s
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,838 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_userauth -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of the ssh-userauth service.
|
||||
Currently implemented authentication types are public-key and password.
|
||||
|
||||
Maintainer: Paul Swartz
|
||||
"""
|
||||
|
||||
import struct
|
||||
from twisted.conch import error, interfaces
|
||||
from twisted.conch.ssh import keys, transport, service
|
||||
from twisted.conch.ssh.common import NS, getNS
|
||||
from twisted.cred import credentials
|
||||
from twisted.cred.error import UnauthorizedLogin
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.python import failure, log
|
||||
|
||||
|
||||
|
||||
class SSHUserAuthServer(service.SSHService):
|
||||
"""
|
||||
A service implementing the server side of the 'ssh-userauth' service. It
|
||||
is used to authenticate the user on the other side as being able to access
|
||||
this server.
|
||||
|
||||
@ivar name: the name of this service: 'ssh-userauth'
|
||||
@type name: C{str}
|
||||
@ivar authenticatedWith: a list of authentication methods that have
|
||||
already been used.
|
||||
@type authenticatedWith: C{list}
|
||||
@ivar loginTimeout: the number of seconds we wait before disconnecting
|
||||
the user for taking too long to authenticate
|
||||
@type loginTimeout: C{int}
|
||||
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
|
||||
allow before disconnecting.
|
||||
@type attemptsBeforeDisconnect: C{int}
|
||||
@ivar loginAttempts: the number of login attempts that have been made
|
||||
@type loginAttempts: C{int}
|
||||
@ivar passwordDelay: the number of seconds to delay when the user gives
|
||||
an incorrect password
|
||||
@type passwordDelay: C{int}
|
||||
@ivar interfaceToMethod: a C{dict} mapping credential interfaces to
|
||||
authentication methods. The server checks to see which of the
|
||||
cred interfaces have checkers and tells the client that those methods
|
||||
are valid for authentication.
|
||||
@type interfaceToMethod: C{dict}
|
||||
@ivar supportedAuthentications: A list of the supported authentication
|
||||
methods.
|
||||
@type supportedAuthentications: C{list} of C{str}
|
||||
@ivar user: the last username the client tried to authenticate with
|
||||
@type user: C{str}
|
||||
@ivar method: the current authentication method
|
||||
@type method: C{str}
|
||||
@ivar nextService: the service the user wants started after authentication
|
||||
has been completed.
|
||||
@type nextService: C{str}
|
||||
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
|
||||
authentication
|
||||
@type portal: L{twisted.cred.portal.Portal}
|
||||
@ivar clock: an object with a callLater method. Stubbed out for testing.
|
||||
"""
|
||||
|
||||
|
||||
name = 'ssh-userauth'
|
||||
loginTimeout = 10 * 60 * 60
|
||||
# 10 minutes before we disconnect them
|
||||
attemptsBeforeDisconnect = 20
|
||||
# 20 login attempts before a disconnect
|
||||
passwordDelay = 1 # number of seconds to delay on a failed password
|
||||
clock = reactor
|
||||
interfaceToMethod = {
|
||||
credentials.ISSHPrivateKey : 'publickey',
|
||||
credentials.IUsernamePassword : 'password',
|
||||
credentials.IPluggableAuthenticationModules : 'keyboard-interactive',
|
||||
}
|
||||
|
||||
|
||||
def serviceStarted(self):
|
||||
"""
|
||||
Called when the userauth service is started. Set up instance
|
||||
variables, check if we should allow password/keyboard-interactive
|
||||
authentication (only allow if the outgoing connection is encrypted) and
|
||||
set up a login timeout.
|
||||
"""
|
||||
self.authenticatedWith = []
|
||||
self.loginAttempts = 0
|
||||
self.user = None
|
||||
self.nextService = None
|
||||
self._pamDeferred = None
|
||||
self.portal = self.transport.factory.portal
|
||||
|
||||
self.supportedAuthentications = []
|
||||
for i in self.portal.listCredentialsInterfaces():
|
||||
if i in self.interfaceToMethod:
|
||||
self.supportedAuthentications.append(self.interfaceToMethod[i])
|
||||
|
||||
if not self.transport.isEncrypted('in'):
|
||||
# don't let us transport password in plaintext
|
||||
if 'password' in self.supportedAuthentications:
|
||||
self.supportedAuthentications.remove('password')
|
||||
if 'keyboard-interactive' in self.supportedAuthentications:
|
||||
self.supportedAuthentications.remove('keyboard-interactive')
|
||||
self._cancelLoginTimeout = self.clock.callLater(
|
||||
self.loginTimeout,
|
||||
self.timeoutAuthentication)
|
||||
|
||||
|
||||
def serviceStopped(self):
|
||||
"""
|
||||
Called when the userauth service is stopped. Cancel the login timeout
|
||||
if it's still going.
|
||||
"""
|
||||
if self._cancelLoginTimeout:
|
||||
self._cancelLoginTimeout.cancel()
|
||||
self._cancelLoginTimeout = None
|
||||
|
||||
|
||||
def timeoutAuthentication(self):
|
||||
"""
|
||||
Called when the user has timed out on authentication. Disconnect
|
||||
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
|
||||
"""
|
||||
self._cancelLoginTimeout = None
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
|
||||
'you took too long')
|
||||
|
||||
|
||||
def tryAuth(self, kind, user, data):
|
||||
"""
|
||||
Try to authenticate the user with the given method. Dispatches to a
|
||||
auth_* method.
|
||||
|
||||
@param kind: the authentication method to try.
|
||||
@type kind: C{str}
|
||||
@param user: the username the client is authenticating with.
|
||||
@type user: C{str}
|
||||
@param data: authentication specific data sent by the client.
|
||||
@type data: C{str}
|
||||
@return: A Deferred called back if the method succeeded, or erred back
|
||||
if it failed.
|
||||
@rtype: C{defer.Deferred}
|
||||
"""
|
||||
log.msg('%s trying auth %s' % (user, kind))
|
||||
if kind not in self.supportedAuthentications:
|
||||
return defer.fail(
|
||||
error.ConchError('unsupported authentication, failing'))
|
||||
kind = kind.replace('-', '_')
|
||||
f = getattr(self,'auth_%s'%kind, None)
|
||||
if f:
|
||||
ret = f(data)
|
||||
if not ret:
|
||||
return defer.fail(
|
||||
error.ConchError('%s return None instead of a Deferred'
|
||||
% kind))
|
||||
else:
|
||||
return ret
|
||||
return defer.fail(error.ConchError('bad auth type: %s' % kind))
|
||||
|
||||
|
||||
def ssh_USERAUTH_REQUEST(self, packet):
|
||||
"""
|
||||
The client has requested authentication. Payload::
|
||||
string user
|
||||
string next service
|
||||
string method
|
||||
<authentication specific data>
|
||||
|
||||
@type packet: C{str}
|
||||
"""
|
||||
user, nextService, method, rest = getNS(packet, 3)
|
||||
if user != self.user or nextService != self.nextService:
|
||||
self.authenticatedWith = [] # clear auth state
|
||||
self.user = user
|
||||
self.nextService = nextService
|
||||
self.method = method
|
||||
d = self.tryAuth(method, user, rest)
|
||||
if not d:
|
||||
self._ebBadAuth(
|
||||
failure.Failure(error.ConchError('auth returned none')))
|
||||
return
|
||||
d.addCallback(self._cbFinishedAuth)
|
||||
d.addErrback(self._ebMaybeBadAuth)
|
||||
d.addErrback(self._ebBadAuth)
|
||||
return d
|
||||
|
||||
|
||||
def _cbFinishedAuth(self, (interface, avatar, logout)):
|
||||
"""
|
||||
The callback when user has successfully been authenticated. For a
|
||||
description of the arguments, see L{twisted.cred.portal.Portal.login}.
|
||||
We start the service requested by the user.
|
||||
"""
|
||||
self.transport.avatar = avatar
|
||||
self.transport.logoutFunction = logout
|
||||
service = self.transport.factory.getService(self.transport,
|
||||
self.nextService)
|
||||
if not service:
|
||||
raise error.ConchError('could not get next service: %s'
|
||||
% self.nextService)
|
||||
log.msg('%s authenticated with %s' % (self.user, self.method))
|
||||
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, '')
|
||||
self.transport.setService(service())
|
||||
|
||||
|
||||
def _ebMaybeBadAuth(self, reason):
|
||||
"""
|
||||
An intermediate errback. If the reason is
|
||||
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
|
||||
with the partial success indicator set.
|
||||
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
reason.trap(error.NotEnoughAuthentication)
|
||||
self.transport.sendPacket(MSG_USERAUTH_FAILURE,
|
||||
NS(','.join(self.supportedAuthentications)) + '\xff')
|
||||
|
||||
|
||||
def _ebBadAuth(self, reason):
|
||||
"""
|
||||
The final errback in the authentication chain. If the reason is
|
||||
error.IgnoreAuthentication, we simply return; the authentication
|
||||
method has sent its own response. Otherwise, send a failure message
|
||||
and (if the method is not 'none') increment the number of login
|
||||
attempts.
|
||||
|
||||
@type reason: L{twisted.python.failure.Failure}
|
||||
"""
|
||||
if reason.check(error.IgnoreAuthentication):
|
||||
return
|
||||
if self.method != 'none':
|
||||
log.msg('%s failed auth %s' % (self.user, self.method))
|
||||
if reason.check(UnauthorizedLogin):
|
||||
log.msg('unauthorized login: %s' % reason.getErrorMessage())
|
||||
elif reason.check(error.ConchError):
|
||||
log.msg('reason: %s' % reason.getErrorMessage())
|
||||
else:
|
||||
log.msg(reason.getTraceback())
|
||||
self.loginAttempts += 1
|
||||
if self.loginAttempts > self.attemptsBeforeDisconnect:
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
|
||||
'too many bad auths')
|
||||
return
|
||||
self.transport.sendPacket(
|
||||
MSG_USERAUTH_FAILURE,
|
||||
NS(','.join(self.supportedAuthentications)) + '\x00')
|
||||
|
||||
|
||||
def auth_publickey(self, packet):
|
||||
"""
|
||||
Public key authentication. Payload::
|
||||
byte has signature
|
||||
string algorithm name
|
||||
string key blob
|
||||
[string signature] (if has signature is True)
|
||||
|
||||
Create a SSHPublicKey credential and verify it using our portal.
|
||||
"""
|
||||
hasSig = ord(packet[0])
|
||||
algName, blob, rest = getNS(packet[1:], 2)
|
||||
pubKey = keys.Key.fromString(blob)
|
||||
signature = hasSig and getNS(rest)[0] or None
|
||||
if hasSig:
|
||||
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
|
||||
NS(self.user) + NS(self.nextService) + NS('publickey') +
|
||||
chr(hasSig) + NS(pubKey.sshType()) + NS(blob))
|
||||
c = credentials.SSHPrivateKey(self.user, algName, blob, b,
|
||||
signature)
|
||||
return self.portal.login(c, None, interfaces.IConchUser)
|
||||
else:
|
||||
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
|
||||
return self.portal.login(c, None,
|
||||
interfaces.IConchUser).addErrback(self._ebCheckKey,
|
||||
packet[1:])
|
||||
|
||||
|
||||
def _ebCheckKey(self, reason, packet):
|
||||
"""
|
||||
Called back if the user did not sent a signature. If reason is
|
||||
error.ValidPublicKey then this key is valid for the user to
|
||||
authenticate with. Send MSG_USERAUTH_PK_OK.
|
||||
"""
|
||||
reason.trap(error.ValidPublicKey)
|
||||
# if we make it here, it means that the publickey is valid
|
||||
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
|
||||
return failure.Failure(error.IgnoreAuthentication())
|
||||
|
||||
|
||||
def auth_password(self, packet):
|
||||
"""
|
||||
Password authentication. Payload::
|
||||
string password
|
||||
|
||||
Make a UsernamePassword credential and verify it with our portal.
|
||||
"""
|
||||
password = getNS(packet[1:])[0]
|
||||
c = credentials.UsernamePassword(self.user, password)
|
||||
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
|
||||
self._ebPassword)
|
||||
|
||||
|
||||
def _ebPassword(self, f):
|
||||
"""
|
||||
If the password is invalid, wait before sending the failure in order
|
||||
to delay brute-force password guessing.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
self.clock.callLater(self.passwordDelay, d.callback, f)
|
||||
return d
|
||||
|
||||
|
||||
def auth_keyboard_interactive(self, packet):
|
||||
"""
|
||||
Keyboard interactive authentication. No payload. We create a
|
||||
PluggableAuthenticationModules credential and authenticate with our
|
||||
portal.
|
||||
"""
|
||||
if self._pamDeferred is not None:
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_PROTOCOL_ERROR,
|
||||
"only one keyboard interactive attempt at a time")
|
||||
return defer.fail(error.IgnoreAuthentication())
|
||||
c = credentials.PluggableAuthenticationModules(self.user,
|
||||
self._pamConv)
|
||||
return self.portal.login(c, None, interfaces.IConchUser)
|
||||
|
||||
|
||||
def _pamConv(self, items):
|
||||
"""
|
||||
Convert a list of PAM authentication questions into a
|
||||
MSG_USERAUTH_INFO_REQUEST. Returns a Deferred that will be called
|
||||
back when the user has responses to the questions.
|
||||
|
||||
@param items: a list of 2-tuples (message, kind). We only care about
|
||||
kinds 1 (password) and 2 (text).
|
||||
@type items: C{list}
|
||||
@rtype: L{defer.Deferred}
|
||||
"""
|
||||
resp = []
|
||||
for message, kind in items:
|
||||
if kind == 1: # password
|
||||
resp.append((message, 0))
|
||||
elif kind == 2: # text
|
||||
resp.append((message, 1))
|
||||
elif kind in (3, 4):
|
||||
return defer.fail(error.ConchError(
|
||||
'cannot handle PAM 3 or 4 messages'))
|
||||
else:
|
||||
return defer.fail(error.ConchError(
|
||||
'bad PAM auth kind %i' % kind))
|
||||
packet = NS('') + NS('') + NS('')
|
||||
packet += struct.pack('>L', len(resp))
|
||||
for prompt, echo in resp:
|
||||
packet += NS(prompt)
|
||||
packet += chr(echo)
|
||||
self.transport.sendPacket(MSG_USERAUTH_INFO_REQUEST, packet)
|
||||
self._pamDeferred = defer.Deferred()
|
||||
return self._pamDeferred
|
||||
|
||||
|
||||
def ssh_USERAUTH_INFO_RESPONSE(self, packet):
|
||||
"""
|
||||
The user has responded with answers to PAMs authentication questions.
|
||||
Parse the packet into a PAM response and callback self._pamDeferred.
|
||||
Payload::
|
||||
uint32 numer of responses
|
||||
string response 1
|
||||
...
|
||||
string response n
|
||||
"""
|
||||
d, self._pamDeferred = self._pamDeferred, None
|
||||
|
||||
try:
|
||||
resp = []
|
||||
numResps = struct.unpack('>L', packet[:4])[0]
|
||||
packet = packet[4:]
|
||||
while len(resp) < numResps:
|
||||
response, packet = getNS(packet)
|
||||
resp.append((response, 0))
|
||||
if packet:
|
||||
raise error.ConchError("%i bytes of extra data" % len(packet))
|
||||
except:
|
||||
d.errback(failure.Failure())
|
||||
else:
|
||||
d.callback(resp)
|
||||
|
||||
|
||||
|
||||
class SSHUserAuthClient(service.SSHService):
|
||||
"""
|
||||
A service implementing the client side of 'ssh-userauth'.
|
||||
|
||||
This service will try all authentication methods provided by the server,
|
||||
making callbacks for more information when necessary.
|
||||
|
||||
@ivar name: the name of this service: 'ssh-userauth'
|
||||
@type name: C{str}
|
||||
@ivar preferredOrder: a list of authentication methods that should be used
|
||||
first, in order of preference, if supported by the server
|
||||
@type preferredOrder: C{list}
|
||||
@ivar user: the name of the user to authenticate as
|
||||
@type user: C{str}
|
||||
@ivar instance: the service to start after authentication has finished
|
||||
@type instance: L{service.SSHService}
|
||||
@ivar authenticatedWith: a list of strings of authentication methods we've tried
|
||||
@type authenticatedWith: C{list} of C{str}
|
||||
@ivar triedPublicKeys: a list of public key objects that we've tried to
|
||||
authenticate with
|
||||
@type triedPublicKeys: C{list} of L{Key}
|
||||
@ivar lastPublicKey: the last public key object we've tried to authenticate
|
||||
with
|
||||
@type lastPublicKey: L{Key}
|
||||
"""
|
||||
|
||||
|
||||
name = 'ssh-userauth'
|
||||
preferredOrder = ['publickey', 'password', 'keyboard-interactive']
|
||||
|
||||
|
||||
def __init__(self, user, instance):
|
||||
self.user = user
|
||||
self.instance = instance
|
||||
|
||||
|
||||
def serviceStarted(self):
|
||||
self.authenticatedWith = []
|
||||
self.triedPublicKeys = []
|
||||
self.lastPublicKey = None
|
||||
self.askForAuth('none', '')
|
||||
|
||||
|
||||
def askForAuth(self, kind, extraData):
|
||||
"""
|
||||
Send a MSG_USERAUTH_REQUEST.
|
||||
|
||||
@param kind: the authentication method to try.
|
||||
@type kind: C{str}
|
||||
@param extraData: method-specific data to go in the packet
|
||||
@type extraData: C{str}
|
||||
"""
|
||||
self.lastAuth = kind
|
||||
self.transport.sendPacket(MSG_USERAUTH_REQUEST, NS(self.user) +
|
||||
NS(self.instance.name) + NS(kind) + extraData)
|
||||
|
||||
|
||||
def tryAuth(self, kind):
|
||||
"""
|
||||
Dispatch to an authentication method.
|
||||
|
||||
@param kind: the authentication method
|
||||
@type kind: C{str}
|
||||
"""
|
||||
kind = kind.replace('-', '_')
|
||||
log.msg('trying to auth with %s' % (kind,))
|
||||
f = getattr(self,'auth_%s' % (kind,), None)
|
||||
if f:
|
||||
return f()
|
||||
|
||||
|
||||
def _ebAuth(self, ignored, *args):
|
||||
"""
|
||||
Generic callback for a failed authentication attempt. Respond by
|
||||
asking for the list of accepted methods (the 'none' method)
|
||||
"""
|
||||
self.askForAuth('none', '')
|
||||
|
||||
|
||||
def ssh_USERAUTH_SUCCESS(self, packet):
|
||||
"""
|
||||
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
|
||||
authentication, so start the next service.
|
||||
"""
|
||||
self.transport.setService(self.instance)
|
||||
|
||||
|
||||
def ssh_USERAUTH_FAILURE(self, packet):
|
||||
"""
|
||||
We received a MSG_USERAUTH_FAILURE. Payload::
|
||||
string methods
|
||||
byte partial success
|
||||
|
||||
If partial success is C{True}, then the previous method succeeded but is
|
||||
not sufficient for authentication. C{methods} is a comma-separated list
|
||||
of accepted authentication methods.
|
||||
|
||||
We sort the list of methods by their position in C{self.preferredOrder},
|
||||
removing methods that have already succeeded. We then call
|
||||
C{self.tryAuth} with the most preferred method.
|
||||
|
||||
@param packet: the L{MSG_USERAUTH_FAILURE} payload.
|
||||
@type packet: C{str}
|
||||
|
||||
@return: a L{defer.Deferred} that will be callbacked with C{None} as
|
||||
soon as all authentication methods have been tried, or C{None} if no
|
||||
more authentication methods are available.
|
||||
@rtype: C{defer.Deferred} or C{None}
|
||||
"""
|
||||
canContinue, partial = getNS(packet)
|
||||
partial = ord(partial)
|
||||
if partial:
|
||||
self.authenticatedWith.append(self.lastAuth)
|
||||
|
||||
def orderByPreference(meth):
|
||||
"""
|
||||
Invoked once per authentication method in order to extract a
|
||||
comparison key which is then used for sorting.
|
||||
|
||||
@param meth: the authentication method.
|
||||
@type meth: C{str}
|
||||
|
||||
@return: the comparison key for C{meth}.
|
||||
@rtype: C{int}
|
||||
"""
|
||||
if meth in self.preferredOrder:
|
||||
return self.preferredOrder.index(meth)
|
||||
else:
|
||||
# put the element at the end of the list.
|
||||
return len(self.preferredOrder)
|
||||
|
||||
canContinue = sorted([meth for meth in canContinue.split(',')
|
||||
if meth not in self.authenticatedWith],
|
||||
key=orderByPreference)
|
||||
|
||||
log.msg('can continue with: %s' % canContinue)
|
||||
return self._cbUserauthFailure(None, iter(canContinue))
|
||||
|
||||
|
||||
def _cbUserauthFailure(self, result, iterator):
|
||||
if result:
|
||||
return
|
||||
try:
|
||||
method = iterator.next()
|
||||
except StopIteration:
|
||||
self.transport.sendDisconnect(
|
||||
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
|
||||
'no more authentication methods available')
|
||||
else:
|
||||
d = defer.maybeDeferred(self.tryAuth, method)
|
||||
d.addCallback(self._cbUserauthFailure, iterator)
|
||||
return d
|
||||
|
||||
|
||||
def ssh_USERAUTH_PK_OK(self, packet):
|
||||
"""
|
||||
This message (number 60) can mean several different messages depending
|
||||
on the current authentication type. We dispatch to individual methods
|
||||
in order to handle this request.
|
||||
"""
|
||||
func = getattr(self, 'ssh_USERAUTH_PK_OK_%s' %
|
||||
self.lastAuth.replace('-', '_'), None)
|
||||
if func is not None:
|
||||
return func(packet)
|
||||
else:
|
||||
self.askForAuth('none', '')
|
||||
|
||||
|
||||
def ssh_USERAUTH_PK_OK_publickey(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_PK. Our public key is valid, so we create a
|
||||
signature and try to authenticate with it.
|
||||
"""
|
||||
publicKey = self.lastPublicKey
|
||||
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
|
||||
NS(self.user) + NS(self.instance.name) + NS('publickey') +
|
||||
'\x01' + NS(publicKey.sshType()) + NS(publicKey.blob()))
|
||||
d = self.signData(publicKey, b)
|
||||
if not d:
|
||||
self.askForAuth('none', '')
|
||||
# this will fail, we'll move on
|
||||
return
|
||||
d.addCallback(self._cbSignedData)
|
||||
d.addErrback(self._ebAuth)
|
||||
|
||||
|
||||
def ssh_USERAUTH_PK_OK_password(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
|
||||
We ask for an old password and a new password, then send both back to
|
||||
the server.
|
||||
"""
|
||||
prompt, language, rest = getNS(packet, 2)
|
||||
self._oldPass = self._newPass = None
|
||||
d = self.getPassword('Old Password: ')
|
||||
d = d.addCallbacks(self._setOldPass, self._ebAuth)
|
||||
d.addCallback(lambda ignored: self.getPassword(prompt))
|
||||
d.addCallbacks(self._setNewPass, self._ebAuth)
|
||||
|
||||
|
||||
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
|
||||
"""
|
||||
This is MSG_USERAUTH_INFO_RESPONSE. The server has sent us the
|
||||
questions it wants us to answer, so we ask the user and sent the
|
||||
responses.
|
||||
"""
|
||||
name, instruction, lang, data = getNS(packet, 3)
|
||||
numPrompts = struct.unpack('!L', data[:4])[0]
|
||||
data = data[4:]
|
||||
prompts = []
|
||||
for i in range(numPrompts):
|
||||
prompt, data = getNS(data)
|
||||
echo = bool(ord(data[0]))
|
||||
data = data[1:]
|
||||
prompts.append((prompt, echo))
|
||||
d = self.getGenericAnswers(name, instruction, prompts)
|
||||
d.addCallback(self._cbGenericAnswers)
|
||||
d.addErrback(self._ebAuth)
|
||||
|
||||
|
||||
def _cbSignedData(self, signedData):
|
||||
"""
|
||||
Called back out of self.signData with the signed data. Send the
|
||||
authentication request with the signature.
|
||||
|
||||
@param signedData: the data signed by the user's private key.
|
||||
@type signedData: C{str}
|
||||
"""
|
||||
publicKey = self.lastPublicKey
|
||||
self.askForAuth('publickey', '\x01' + NS(publicKey.sshType()) +
|
||||
NS(publicKey.blob()) + NS(signedData))
|
||||
|
||||
|
||||
def _setOldPass(self, op):
|
||||
"""
|
||||
Called back when we are choosing a new password. Simply store the old
|
||||
password for now.
|
||||
|
||||
@param op: the old password as entered by the user
|
||||
@type op: C{str}
|
||||
"""
|
||||
self._oldPass = op
|
||||
|
||||
|
||||
def _setNewPass(self, np):
|
||||
"""
|
||||
Called back when we are choosing a new password. Get the old password
|
||||
and send the authentication message with both.
|
||||
|
||||
@param np: the new password as entered by the user
|
||||
@type np: C{str}
|
||||
"""
|
||||
op = self._oldPass
|
||||
self._oldPass = None
|
||||
self.askForAuth('password', '\xff' + NS(op) + NS(np))
|
||||
|
||||
|
||||
def _cbGenericAnswers(self, responses):
|
||||
"""
|
||||
Called back when we are finished answering keyboard-interactive
|
||||
questions. Send the info back to the server in a
|
||||
MSG_USERAUTH_INFO_RESPONSE.
|
||||
|
||||
@param responses: a list of C{str} responses
|
||||
@type responses: C{list}
|
||||
"""
|
||||
data = struct.pack('!L', len(responses))
|
||||
for r in responses:
|
||||
data += NS(r.encode('UTF8'))
|
||||
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
|
||||
|
||||
|
||||
def auth_publickey(self):
|
||||
"""
|
||||
Try to authenticate with a public key. Ask the user for a public key;
|
||||
if the user has one, send the request to the server and return True.
|
||||
Otherwise, return False.
|
||||
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
d = defer.maybeDeferred(self.getPublicKey)
|
||||
d.addBoth(self._cbGetPublicKey)
|
||||
return d
|
||||
|
||||
|
||||
def _cbGetPublicKey(self, publicKey):
|
||||
if not isinstance(publicKey, keys.Key): # failure or None
|
||||
publicKey = None
|
||||
if publicKey is not None:
|
||||
self.lastPublicKey = publicKey
|
||||
self.triedPublicKeys.append(publicKey)
|
||||
log.msg('using key of type %s' % publicKey.type())
|
||||
self.askForAuth('publickey', '\x00' + NS(publicKey.sshType()) +
|
||||
NS(publicKey.blob()))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def auth_password(self):
|
||||
"""
|
||||
Try to authenticate with a password. Ask the user for a password.
|
||||
If the user will return a password, return True. Otherwise, return
|
||||
False.
|
||||
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
d = self.getPassword()
|
||||
if d:
|
||||
d.addCallbacks(self._cbPassword, self._ebAuth)
|
||||
return True
|
||||
else: # returned None, don't do password auth
|
||||
return False
|
||||
|
||||
|
||||
def auth_keyboard_interactive(self):
|
||||
"""
|
||||
Try to authenticate with keyboard-interactive authentication. Send
|
||||
the request to the server and return True.
|
||||
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
log.msg('authing with keyboard-interactive')
|
||||
self.askForAuth('keyboard-interactive', NS('') + NS(''))
|
||||
return True
|
||||
|
||||
|
||||
def _cbPassword(self, password):
|
||||
"""
|
||||
Called back when the user gives a password. Send the request to the
|
||||
server.
|
||||
|
||||
@param password: the password the user entered
|
||||
@type password: C{str}
|
||||
"""
|
||||
self.askForAuth('password', '\x00' + NS(password))
|
||||
|
||||
|
||||
def signData(self, publicKey, signData):
|
||||
"""
|
||||
Sign the given data with the given public key.
|
||||
|
||||
By default, this will call getPrivateKey to get the private key,
|
||||
then sign the data using Key.sign().
|
||||
|
||||
This method is factored out so that it can be overridden to use
|
||||
alternate methods, such as a key agent.
|
||||
|
||||
@param publicKey: The public key object returned from L{getPublicKey}
|
||||
@type publicKey: L{keys.Key}
|
||||
|
||||
@param signData: the data to be signed by the private key.
|
||||
@type signData: C{str}
|
||||
@return: a Deferred that's called back with the signature
|
||||
@rtype: L{defer.Deferred}
|
||||
"""
|
||||
key = self.getPrivateKey()
|
||||
if not key:
|
||||
return
|
||||
return key.addCallback(self._cbSignData, signData)
|
||||
|
||||
|
||||
def _cbSignData(self, privateKey, signData):
|
||||
"""
|
||||
Called back when the private key is returned. Sign the data and
|
||||
return the signature.
|
||||
|
||||
@param privateKey: the private key object
|
||||
@type publicKey: L{keys.Key}
|
||||
@param signData: the data to be signed by the private key.
|
||||
@type signData: C{str}
|
||||
@return: the signature
|
||||
@rtype: C{str}
|
||||
"""
|
||||
return privateKey.sign(signData)
|
||||
|
||||
|
||||
def getPublicKey(self):
|
||||
"""
|
||||
Return a public key for the user. If no more public keys are
|
||||
available, return C{None}.
|
||||
|
||||
This implementation always returns C{None}. Override it in a
|
||||
subclass to actually find and return a public key object.
|
||||
|
||||
@rtype: L{Key} or L{NoneType}
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def getPrivateKey(self):
|
||||
"""
|
||||
Return a L{Deferred} that will be called back with the private key
|
||||
object corresponding to the last public key from getPublicKey().
|
||||
If the private key is not available, errback on the Deferred.
|
||||
|
||||
@rtype: L{Deferred} called back with L{Key}
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
|
||||
def getPassword(self, prompt = None):
|
||||
"""
|
||||
Return a L{Deferred} that will be called back with a password.
|
||||
prompt is a string to display for the password, or None for a generic
|
||||
'user@hostname's password: '.
|
||||
|
||||
@type prompt: C{str}/C{None}
|
||||
@rtype: L{defer.Deferred}
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
|
||||
def getGenericAnswers(self, name, instruction, prompts):
|
||||
"""
|
||||
Returns a L{Deferred} with the responses to the promopts.
|
||||
|
||||
@param name: The name of the authentication currently in progress.
|
||||
@param instruction: Describes what the authentication wants.
|
||||
@param prompts: A list of (prompt, echo) pairs, where prompt is a
|
||||
string to display and echo is a boolean indicating whether the
|
||||
user's response should be echoed as they type it.
|
||||
"""
|
||||
return defer.fail(NotImplementedError())
|
||||
|
||||
|
||||
MSG_USERAUTH_REQUEST = 50
|
||||
MSG_USERAUTH_FAILURE = 51
|
||||
MSG_USERAUTH_SUCCESS = 52
|
||||
MSG_USERAUTH_BANNER = 53
|
||||
MSG_USERAUTH_INFO_RESPONSE = 61
|
||||
MSG_USERAUTH_PK_OK = 60
|
||||
|
||||
messages = {}
|
||||
for k, v in locals().items():
|
||||
if k[:4]=='MSG_':
|
||||
messages[v] = k
|
||||
|
||||
SSHUserAuthServer.protocolMessages = messages
|
||||
SSHUserAuthClient.protocolMessages = messages
|
||||
del messages
|
||||
del v
|
||||
|
||||
# Doubles, not included in the protocols' mappings
|
||||
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
|
||||
MSG_USERAUTH_INFO_REQUEST = 60
|
@ -0,0 +1,95 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_manhole -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Asynchronous local terminal input handling
|
||||
|
||||
@author: Jp Calderone
|
||||
"""
|
||||
|
||||
import os, tty, sys, termios
|
||||
|
||||
from twisted.internet import reactor, stdio, protocol, defer
|
||||
from twisted.python import failure, reflect, log
|
||||
|
||||
from twisted.conch.insults.insults import ServerProtocol
|
||||
from twisted.conch.manhole import ColoredManhole
|
||||
|
||||
class UnexpectedOutputError(Exception):
|
||||
pass
|
||||
|
||||
class TerminalProcessProtocol(protocol.ProcessProtocol):
|
||||
def __init__(self, proto):
|
||||
self.proto = proto
|
||||
self.onConnection = defer.Deferred()
|
||||
|
||||
def connectionMade(self):
|
||||
self.proto.makeConnection(self)
|
||||
self.onConnection.callback(None)
|
||||
self.onConnection = None
|
||||
|
||||
def write(self, bytes):
|
||||
self.transport.write(bytes)
|
||||
|
||||
def outReceived(self, bytes):
|
||||
self.proto.dataReceived(bytes)
|
||||
|
||||
def errReceived(self, bytes):
|
||||
self.transport.loseConnection()
|
||||
if self.proto is not None:
|
||||
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(bytes)))
|
||||
self.proto = None
|
||||
|
||||
def childConnectionLost(self, childFD):
|
||||
if self.proto is not None:
|
||||
self.proto.childConnectionLost(childFD)
|
||||
|
||||
def processEnded(self, reason):
|
||||
if self.proto is not None:
|
||||
self.proto.connectionLost(reason)
|
||||
self.proto = None
|
||||
|
||||
|
||||
|
||||
class ConsoleManhole(ColoredManhole):
|
||||
"""
|
||||
A manhole protocol specifically for use with L{stdio.StandardIO}.
|
||||
"""
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
When the connection is lost, there is nothing more to do. Stop the
|
||||
reactor so that the process can exit.
|
||||
"""
|
||||
reactor.stop()
|
||||
|
||||
|
||||
|
||||
def runWithProtocol(klass):
|
||||
fd = sys.__stdin__.fileno()
|
||||
oldSettings = termios.tcgetattr(fd)
|
||||
tty.setraw(fd)
|
||||
try:
|
||||
p = ServerProtocol(klass)
|
||||
stdio.StandardIO(p)
|
||||
reactor.run()
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
|
||||
os.write(fd, "\r\x1bc\r")
|
||||
|
||||
|
||||
|
||||
def main(argv=None):
|
||||
log.startLogging(file('child.log', 'w'))
|
||||
|
||||
if argv is None:
|
||||
argv = sys.argv[1:]
|
||||
if argv:
|
||||
klass = reflect.namedClass(argv[0])
|
||||
else:
|
||||
klass = ConsoleManhole
|
||||
runWithProtocol(klass)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@ -0,0 +1,93 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_tap -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Support module for making SSH servers with twistd.
|
||||
"""
|
||||
|
||||
from twisted.conch import unix
|
||||
from twisted.conch import checkers as conch_checkers
|
||||
from twisted.conch.openssh_compat import factory
|
||||
from twisted.cred import portal, checkers, strcred
|
||||
from twisted.python import usage
|
||||
from twisted.application import strports
|
||||
try:
|
||||
from twisted.cred import pamauth
|
||||
except ImportError:
|
||||
pamauth = None
|
||||
|
||||
|
||||
|
||||
class Options(usage.Options, strcred.AuthOptionMixin):
|
||||
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
|
||||
longdesc = ("Makes a Conch SSH server. If no authentication methods are "
|
||||
"specified, the default authentication methods are UNIX passwords, "
|
||||
"SSH public keys, and PAM if it is available. If --auth options are "
|
||||
"passed, only the measures specified will be used.")
|
||||
optParameters = [
|
||||
["interface", "i", "", "local interface to which we listen"],
|
||||
["port", "p", "tcp:22", "Port on which to listen"],
|
||||
["data", "d", "/etc", "directory to look for host keys in"],
|
||||
["moduli", "", None, "directory to look for moduli in "
|
||||
"(if different from --data)"]
|
||||
]
|
||||
compData = usage.Completions(
|
||||
optActions={"data": usage.CompleteDirs(descr="data directory"),
|
||||
"moduli": usage.CompleteDirs(descr="moduli directory"),
|
||||
"interface": usage.CompleteNetInterfaces()}
|
||||
)
|
||||
|
||||
|
||||
def __init__(self, *a, **kw):
|
||||
usage.Options.__init__(self, *a, **kw)
|
||||
|
||||
# call the default addCheckers (for backwards compatibility) that will
|
||||
# be used if no --auth option is provided - note that conch's
|
||||
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
|
||||
# checker
|
||||
super(Options, self).addChecker(conch_checkers.UNIXPasswordDatabase())
|
||||
super(Options, self).addChecker(conch_checkers.SSHPublicKeyChecker(
|
||||
conch_checkers.UNIXAuthorizedKeysFiles()))
|
||||
if pamauth is not None:
|
||||
super(Options, self).addChecker(
|
||||
checkers.PluggableAuthenticationModulesChecker())
|
||||
self._usingDefaultAuth = True
|
||||
|
||||
|
||||
def addChecker(self, checker):
|
||||
"""
|
||||
Add the checker specified. If any checkers are added, the default
|
||||
checkers are automatically cleared and the only checkers will be the
|
||||
specified one(s).
|
||||
"""
|
||||
if self._usingDefaultAuth:
|
||||
self['credCheckers'] = []
|
||||
self['credInterfaces'] = {}
|
||||
self._usingDefaultAuth = False
|
||||
super(Options, self).addChecker(checker)
|
||||
|
||||
|
||||
|
||||
def makeService(config):
|
||||
"""
|
||||
Construct a service for operating a SSH server.
|
||||
|
||||
@param config: An L{Options} instance specifying server options, including
|
||||
where server keys are stored and what authentication methods to use.
|
||||
|
||||
@return: An L{IService} provider which contains the requested SSH server.
|
||||
"""
|
||||
|
||||
t = factory.OpenSSHFactory()
|
||||
|
||||
r = unix.UnixSSHRealm()
|
||||
t.portal = portal.Portal(r, config.get('credCheckers', []))
|
||||
t.dataRoot = config['data']
|
||||
t.moduliRoot = config['moduli'] or config['data']
|
||||
|
||||
port = config['port']
|
||||
if config['interface']:
|
||||
# Add warning here
|
||||
port += ':interface=' + config['interface']
|
||||
return strports.service(port, t)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1 @@
|
||||
'conch tests'
|
@ -0,0 +1,208 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_keys -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Data used by test_keys as well as others.
|
||||
"""
|
||||
RSAData = {
|
||||
'n':long('1062486685755247411169438309495398947372127791189432809481'
|
||||
'382072971106157632182084539383569281493520117634129557550415277'
|
||||
'516685881326038852354459895734875625093273594925884531272867425'
|
||||
'864910490065695876046999646807138717162833156501L'),
|
||||
'e':35L,
|
||||
'd':long('6678487739032983727350755088256793383481946116047863373882'
|
||||
'973030104095847973715959961839578340816412167985957218887914482'
|
||||
'713602371850869127033494910375212470664166001439410214474266799'
|
||||
'85974425203903884190893469297150446322896587555L'),
|
||||
'q':long('3395694744258061291019136154000709371890447462086362702627'
|
||||
'9704149412726577280741108645721676968699696898960891593323L'),
|
||||
'p':long('3128922844292337321766351031842562691837301298995834258844'
|
||||
'4720539204069737532863831050930719431498338835415515173887L')}
|
||||
|
||||
DSAData = {
|
||||
'y':long('2300663509295750360093768159135720439490120577534296730713'
|
||||
'348508834878775464483169644934425336771277908527130096489120714'
|
||||
'610188630979820723924744291603865L'),
|
||||
'g':long('4451569990409370769930903934104221766858515498655655091803'
|
||||
'866645719060300558655677517139568505649468378587802312867198352'
|
||||
'1161998270001677664063945776405L'),
|
||||
'p':long('7067311773048598659694590252855127633397024017439939353776'
|
||||
'608320410518694001356789646664502838652272205440894335303988504'
|
||||
'978724817717069039110940675621677L'),
|
||||
'q':1184501645189849666738820838619601267690550087703L,
|
||||
'x':863951293559205482820041244219051653999559962819L}
|
||||
|
||||
publicRSA_openssh = ("ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAGEArzJx8OYOnJmzf4tfBE"
|
||||
"vLi8DVPrJ3/c9k2I/Az64fxjHf9imyRJbixtQhlH9lfNjUIx+4LmrJH5QNRsFporcHDKOTwTTYL"
|
||||
"h5KmRpslkYHRivcJSkbh/C+BR3utDS555mV comment")
|
||||
|
||||
privateRSA_openssh = """-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIByAIBAAJhAK8ycfDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW
|
||||
4sbUIZR/ZXzY1CMfuC5qyR+UDUbBaaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fw
|
||||
vgUd7rQ0ueeZlQIBIwJgbh+1VZfr7WftK5lu7MHtqE1S1vPWZQYE3+VUn8yJADyb
|
||||
Z4fsZaCrzW9lkIqXkE3GIY+ojdhZhkO1gbG0118sIgphwSWKRxK0mvh6ERxKqIt1
|
||||
xJEJO74EykXZV4oNJ8sjAjEA3J9r2ZghVhGN6V8DnQrTk24Td0E8hU8AcP0FVP+8
|
||||
PQm/g/aXf2QQkQT+omdHVEJrAjEAy0pL0EBH6EVS98evDCBtQw22OZT52qXlAwZ2
|
||||
gyTriKFVoqjeEjt3SZKKqXHSApP/AjBLpF99zcJJZRq2abgYlf9lv1chkrWqDHUu
|
||||
DZttmYJeEfiFBBavVYIF1dOlZT0G8jMCMBc7sOSZodFnAiryP+Qg9otSBjJ3bQML
|
||||
pSTqy7c3a2AScC/YyOwkDaICHnnD3XyjMwIxALRzl0tQEKMXs6hH8ToUdlLROCrP
|
||||
EhQ0wahUTCk1gKA4uPD6TMTChavbh4K63OvbKg==
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# some versions of OpenSSH generate these (slightly different keys)
|
||||
privateRSA_openssh_alternate = """-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIBzjCCAcgCAQACYQCvMnHw5g6cmbN/i18ES8uLwNU+snf9z2TYj8DPrh/GMd/2
|
||||
KbJEluLG1CGUf2V82NQjH7guaskflA1GwWmitwcMo5PBNNguHkqZGmyWRgdGK9wl
|
||||
KRuH8L4FHe60NLnnmZUCASMCYG4ftVWX6+1n7SuZbuzB7ahNUtbz1mUGBN/lVJ/M
|
||||
iQA8m2eH7GWgq81vZZCKl5BNxiGPqI3YWYZDtYGxtNdfLCIKYcElikcStJr4ehEc
|
||||
SqiLdcSRCTu+BMpF2VeKDSfLIwIxANyfa9mYIVYRjelfA50K05NuE3dBPIVPAHD9
|
||||
BVT/vD0Jv4P2l39kEJEE/qJnR1RCawIxAMtKS9BAR+hFUvfHrwwgbUMNtjmU+dql
|
||||
5QMGdoMk64ihVaKo3hI7d0mSiqlx0gKT/wIwS6Rffc3CSWUatmm4GJX/Zb9XIZK1
|
||||
qgx1Lg2bbZmCXhH4hQQWr1WCBdXTpWU9BvIzAjAXO7DkmaHRZwIq8j/kIPaLUgYy
|
||||
d20DC6Uk6su3N2tgEnAv2MjsJA2iAh55w918ozMCMQC0c5dLUBCjF7OoR/E6FHZS
|
||||
0TgqzxIUNMGoVEwpNYCgOLjw+kzEwoWr24eCutzr2yowAA==
|
||||
------END RSA PRIVATE KEY------"""
|
||||
|
||||
# encrypted with the passphrase 'encrypted'
|
||||
privateRSA_openssh_encrypted = """-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
|
||||
|
||||
30qUR7DYY/rpVJu159paRM1mUqt/IMibfEMTKWSjNhCVD21hskftZCJROw/WgIFt
|
||||
ncusHpJMkjgwEpho0KyKilcC7zxjpunTex24Meb5pCdXCrYft8AyUkRdq3dugMqT
|
||||
4nuWuWxziluBhKQ2M9tPGcEOeulU4vVjceZt2pZhZQVBf08o3XUv5/7RYd24M9md
|
||||
WIo+5zdj2YQkI6xMFTP954O/X32ME1KQt98wgNEy6mxhItbvf00mH3woALwEKP3v
|
||||
PSMxxtx3VKeDKd9YTOm1giKkXZUf91vZWs0378tUBrU4U5qJxgryTjvvVKOtofj6
|
||||
4qQy6+r6M6wtwVlXBgeRm2gBPvL3nv6MsROp3E6ztBd/e7A8fSec+UTq3ko/EbGP
|
||||
0QG+IG5tg8FsdITxQ9WAIITZL3Rc6hA5Ymx1VNhySp3iSiso8Jof27lku4pyuvRV
|
||||
ko/B3N2H7LnQrGV0GyrjeYocW/qZh/PCsY48JBFhlNQexn2mn44AJW3y5xgbhvKA
|
||||
3mrmMD1hD17ZvZxi4fPHjbuAyM1vFqhQx63eT9ijbwJ91svKJl5O5MIv41mCRonm
|
||||
hxvOXw8S0mjSasyofptzzQCtXxFLQigXbpQBltII+Ys=
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
# encrypted with the passphrase 'testxp'. NB: this key was generated by
|
||||
# OpenSSH, so it doesn't use the same key data as the other keys here.
|
||||
privateRSA_openssh_encrypted_aes = """-----BEGIN RSA PRIVATE KEY-----
|
||||
Proc-Type: 4,ENCRYPTED
|
||||
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
|
||||
|
||||
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
|
||||
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
|
||||
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
|
||||
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
|
||||
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
|
||||
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
|
||||
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
|
||||
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
|
||||
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
|
||||
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
|
||||
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
|
||||
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
|
||||
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
|
||||
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
|
||||
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
|
||||
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
|
||||
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
|
||||
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
|
||||
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
|
||||
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
|
||||
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
|
||||
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
|
||||
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
|
||||
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
|
||||
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
|
||||
-----END RSA PRIVATE KEY-----"""
|
||||
|
||||
publicRSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuOTc6AK8yc"
|
||||
"fDmDpyZs3+LXwRLy4vA1T6yd/3PZNiPwM+uH8Yx3/YpskSW4sbUIZR/ZXzY1CMfuC5qyR+UDUbB"
|
||||
"aaK3Bwyjk8E02C4eSpkabJZGB0Yr3CUpG4fwvgUd7rQ0ueeZlSkoMTplMTojKSkp}")
|
||||
|
||||
privateRSA_lsh = ("(11:private-key(9:rsa-pkcs1(1:n97:\x00\xaf2q\xf0\xe6\x0e"
|
||||
"\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd\xcfd\xd8\x8f\xc0\xcf"
|
||||
"\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|\xd8\xd4#\x1f\xb8.j"
|
||||
"\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8.\x1eJ\x99\x1al\x96F"
|
||||
"\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7\x99\x95)(1:e1:#)(1:d9"
|
||||
"6:n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
|
||||
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
|
||||
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
|
||||
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#)(1:p49:\x00"
|
||||
"\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd"
|
||||
"\x05T\xff\xbc=\t\xbf\x83\xf6\x97\x7fd\x10\x91\x04\xfe\xa2gGTBk)(1:q49:\x00"
|
||||
"\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r\xb69\x94\xf9\xda\xa5\xe5\x03\x06v"
|
||||
"\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92\x8a\xa9q\xd2\x02\x93\xff)(1:a48:K"
|
||||
"\xa4_}\xcd\xc2Ie\x1a\xb6i\xb8\x18\x95\xffe\xbfW!\x92\xb5\xaa\x0cu.\r\x9bm"
|
||||
"\x99\x82^\x11\xf8\x85\x04\x16\xafU\x82\x05\xd5\xd3\xa5e=\x06\xf23)(1:b48:"
|
||||
"\x17;\xb0\xe4\x99\xa1\xd1g\x02*\xf2?\xe4 \xf6\x8bR\x062wm\x03\x0b\xa5$\xea"
|
||||
"\xcb\xb77k`\x12p/\xd8\xc8\xec$\r\xa2\x02\x1ey\xc3\xdd|\xa33)(1:c49:\x00\xb4"
|
||||
"s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*\xcf\x12\x144\xc1\xa8TL)5\x80"
|
||||
"\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82\xba\xdc\xeb\xdb*)))")
|
||||
|
||||
privateRSA_agentv3 = ("\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x01#\x00\x00\x00`"
|
||||
"n\x1f\xb5U\x97\xeb\xedg\xed+\x99n\xec\xc1\xed\xa8MR\xd6\xf3\xd6e\x06\x04"
|
||||
"\xdf\xe5T\x9f\xcc\x89\x00<\x9bg\x87\xece\xa0\xab\xcdoe\x90\x8a\x97\x90M\xc6"
|
||||
'!\x8f\xa8\x8d\xd8Y\x86C\xb5\x81\xb1\xb4\xd7_,"\na\xc1%\x8aG\x12\xb4\x9a\xf8'
|
||||
"z\x11\x1cJ\xa8\x8bu\xc4\x91\t;\xbe\x04\xcaE\xd9W\x8a\r\'\xcb#\x00\x00\x00a"
|
||||
"\x00\xaf2q\xf0\xe6\x0e\x9c\x99\xb3\x7f\x8b_\x04K\xcb\x8b\xc0\xd5>\xb2w\xfd"
|
||||
"\xcfd\xd8\x8f\xc0\xcf\xae\x1f\xc61\xdf\xf6)\xb2D\x96\xe2\xc6\xd4!\x94\x7fe|"
|
||||
"\xd8\xd4#\x1f\xb8.j\xc9\x1f\x94\rF\xc1i\xa2\xb7\x07\x0c\xa3\x93\xc14\xd8."
|
||||
"\x1eJ\x99\x1al\x96F\x07F+\xdc%)\x1b\x87\xf0\xbe\x05\x1d\xee\xb44\xb9\xe7"
|
||||
"\x99\x95\x00\x00\x001\x00\xb4s\x97KP\x10\xa3\x17\xb3\xa8G\xf1:\x14vR\xd18*"
|
||||
"\xcf\x12\x144\xc1\xa8TL)5\x80\xa08\xb8\xf0\xfaL\xc4\xc2\x85\xab\xdb\x87\x82"
|
||||
"\xba\xdc\xeb\xdb*\x00\x00\x001\x00\xcbJK\xd0@G\xe8ER\xf7\xc7\xaf\x0c mC\r"
|
||||
"\xb69\x94\xf9\xda\xa5\xe5\x03\x06v\x83$\xeb\x88\xa1U\xa2\xa8\xde\x12;wI\x92"
|
||||
"\x8a\xa9q\xd2\x02\x93\xff\x00\x00\x001\x00\xdc\x9fk\xd9\x98!V\x11\x8d\xe9_"
|
||||
"\x03\x9d\n\xd3\x93n\x13wA<\x85O\x00p\xfd\x05T\xff\xbc=\t\xbf\x83\xf6\x97"
|
||||
"\x7fd\x10\x91\x04\xfe\xa2gGTBk")
|
||||
|
||||
publicDSA_openssh = ("ssh-dss AAAAB3NzaC1kc3MAAABBAIbwTOSsZ7Bl7U1KyMNqV13Tu7"
|
||||
"yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0AAAAVAM965Akmo"
|
||||
"6eAi7K+k9qDR4TotFAXAAAAQADZlpTW964haQWS4vC063NGdldT6xpUGDcDRqbm90CoPEa2RmNO"
|
||||
"uOqi8lnbhYraEzypYH3K4Gzv/bxCBnKtHRUAAABAK+1osyWBS0+P90u/rAuko6chZ98thUSY2kL"
|
||||
"SHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmQ== comment")
|
||||
|
||||
privateDSA_openssh = """-----BEGIN DSA PRIVATE KEY-----
|
||||
MIH4AgEAAkEAhvBM5KxnsGXtTUrIw2pXXdO7vJEC1OvvQ9UjdCd+s+6Z/ZTMKCkv
|
||||
WWNsrFJ8CLTle+sT/W98IUCwVhdFkuFDLQIVAM965Akmo6eAi7K+k9qDR4TotFAX
|
||||
AkAA2ZaU1veuIWkFkuLwtOtzRnZXU+saVBg3A0am5vdAqDxGtkZjTrjqovJZ24WK
|
||||
2hM8qWB9yuBs7/28QgZyrR0VAkAr7WizJYFLT4/3S7+sC6SjpyFn3y2FRJjaQtIe
|
||||
nqEsrLLZuOdPb2HuFoQsT5cd+rZsz19yQPHYUs5IgnPLzdWZAhUAl1TqdmlAG/b4
|
||||
nnVchGiO9sML8MM=
|
||||
-----END DSA PRIVATE KEY-----"""
|
||||
|
||||
publicDSA_lsh = ("{KDEwOnB1YmxpYy1rZXkoMzpkc2EoMTpwNjU6AIbwTOSsZ7Bl7U1KyMNqV"
|
||||
"13Tu7yRAtTr70PVI3QnfrPumf2UzCgpL1ljbKxSfAi05XvrE/1vfCFAsFYXRZLhQy0pKDE6cTIx"
|
||||
"OgDPeuQJJqOngIuyvpPag0eE6LRQFykoMTpnNjQ6ANmWlNb3riFpBZLi8LTrc0Z2V1PrGlQYNwN"
|
||||
"Gpub3QKg8RrZGY0646qLyWduFitoTPKlgfcrgbO/9vEIGcq0dFSkoMTp5NjQ6K+1osyWBS0+P90"
|
||||
"u/rAuko6chZ98thUSY2kLSHp6hLKyy2bjnT29h7haELE+XHfq2bM9fckDx2FLOSIJzy83VmSkpK"
|
||||
"Q==}")
|
||||
|
||||
privateDSA_lsh = ("(11:private-key(3:dsa(1:p65:\x00\x86\xf0L\xe4\xacg\xb0e"
|
||||
"\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3\xee\x99\xfd"
|
||||
"\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92\xe1C-)(1:q2"
|
||||
"1:\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G\x84\xe8\xb4P\x17)(1"
|
||||
":g64:\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2\xf0\xb4\xebsFvWS\xeb\x1aT"
|
||||
"\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2\xf2Y\xdb\x85\x8a\xda\x13<"
|
||||
"\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15)(1:y64:+\xedh\xb3%\x81KO\x8f"
|
||||
"\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98\xdaB\xd2\x1e\x9e\xa1,\xac\xb2"
|
||||
"\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6l\xcf_r@\xf1\xd8R\xceH\x82s"
|
||||
"\xcb\xcd\xd5\x99)(1:x21:\x00\x97T\xeavi@\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6"
|
||||
"\xc3\x0b\xf0\xc3)))")
|
||||
|
||||
privateDSA_agentv3 = ("\x00\x00\x00\x07ssh-dss\x00\x00\x00A\x00\x86\xf0L\xe4"
|
||||
"\xacg\xb0e\xedMJ\xc8\xc3jW]\xd3\xbb\xbc\x91\x02\xd4\xeb\xefC\xd5#t'~\xb3"
|
||||
"\xee\x99\xfd\x94\xcc()/Ycl\xacR|\x08\xb4\xe5{\xeb\x13\xfdo|!@\xb0V\x17E\x92"
|
||||
"\xe1C-\x00\x00\x00\x15\x00\xcfz\xe4\t&\xa3\xa7\x80\x8b\xb2\xbe\x93\xda\x83G"
|
||||
"\x84\xe8\xb4P\x17\x00\x00\x00@\x00\xd9\x96\x94\xd6\xf7\xae!i\x05\x92\xe2"
|
||||
"\xf0\xb4\xebsFvWS\xeb\x1aT\x187\x03F\xa6\xe6\xf7@\xa8<F\xb6FcN\xb8\xea\xa2"
|
||||
"\xf2Y\xdb\x85\x8a\xda\x13<\xa9`}\xca\xe0l\xef\xfd\xbcB\x06r\xad\x1d\x15\x00"
|
||||
"\x00\x00@+\xedh\xb3%\x81KO\x8f\xf7K\xbf\xac\x0b\xa4\xa3\xa7!g\xdf-\x85D\x98"
|
||||
"\xdaB\xd2\x1e\x9e\xa1,\xac\xb2\xd9\xb8\xe7Ooa\xee\x16\x84,O\x97\x1d\xfa\xb6"
|
||||
"l\xcf_r@\xf1\xd8R\xceH\x82s\xcb\xcd\xd5\x99\x00\x00\x00\x15\x00\x97T\xeavi@"
|
||||
"\x1b\xf6\xf8\x9eu\\\x84h\x8e\xf6\xc3\x0b\xf0\xc3")
|
||||
|
||||
__all__ = ['DSAData', 'RSAData', 'privateDSA_agentv3', 'privateDSA_lsh',
|
||||
'privateDSA_openssh', 'privateRSA_agentv3', 'privateRSA_lsh',
|
||||
'privateRSA_openssh', 'publicDSA_lsh', 'publicDSA_openssh',
|
||||
'publicRSA_lsh', 'publicRSA_openssh', 'privateRSA_openssh_alternate']
|
||||
|
@ -0,0 +1,49 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{SSHTransportAddrress} in ssh/address.py
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.test.test_address import AddressTestCaseMixin
|
||||
|
||||
from twisted.conch.ssh.address import SSHTransportAddress
|
||||
|
||||
|
||||
|
||||
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
|
||||
"""
|
||||
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
|
||||
use to represent the other side of the SSH connection. This tests the
|
||||
basic functionality of that class (string representation, comparison, &c).
|
||||
"""
|
||||
|
||||
|
||||
def _stringRepresentation(self, stringFunction):
|
||||
"""
|
||||
The string representation of C{SSHTransportAddress} should be
|
||||
"SSHTransportAddress(<stringFunction on address>)".
|
||||
"""
|
||||
addr = self.buildAddress()
|
||||
stringValue = stringFunction(addr)
|
||||
addressValue = stringFunction(addr.address)
|
||||
self.assertEqual(stringValue,
|
||||
"SSHTransportAddress(%s)" % addressValue)
|
||||
|
||||
|
||||
def buildAddress(self):
|
||||
"""
|
||||
Create an arbitrary new C{SSHTransportAddress}. A new instance is
|
||||
created for each call, but always for the same address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
|
||||
|
||||
|
||||
def buildDifferentAddress(self):
|
||||
"""
|
||||
Like C{buildAddress}, but with a different fixed address.
|
||||
"""
|
||||
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))
|
||||
|
@ -0,0 +1,399 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.agent}.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.internet import reactor
|
||||
from twisted.internet.interfaces import ITLSTransport
|
||||
from twisted.trial import unittest
|
||||
|
||||
if not ITLSTransport.providedBy(reactor):
|
||||
iosim = None
|
||||
else:
|
||||
from twisted.test import iosim
|
||||
|
||||
try:
|
||||
import Crypto.Cipher.DES3
|
||||
except ImportError:
|
||||
Crypto = None
|
||||
|
||||
try:
|
||||
import pyasn1
|
||||
except ImportError:
|
||||
pyasn1 = None
|
||||
|
||||
if Crypto and pyasn1:
|
||||
from twisted.conch.ssh import keys, agent
|
||||
else:
|
||||
keys = agent = None
|
||||
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.conch.error import ConchError, MissingKeyStoreError
|
||||
|
||||
|
||||
class StubFactory(object):
|
||||
"""
|
||||
Mock factory that provides the keys attribute required by the
|
||||
SSHAgentServerProtocol
|
||||
"""
|
||||
def __init__(self):
|
||||
self.keys = {}
|
||||
|
||||
|
||||
|
||||
class AgentTestBase(unittest.TestCase):
|
||||
"""
|
||||
Tests for SSHAgentServer/Client.
|
||||
"""
|
||||
if iosim is None:
|
||||
skip = "iosim requires SSL, but SSL is not available"
|
||||
elif agent is None or keys is None:
|
||||
skip = "Cannot run without PyCrypto or PyASN1"
|
||||
|
||||
def setUp(self):
|
||||
# wire up our client <-> server
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
agent.SSHAgentServer, agent.SSHAgentClient)
|
||||
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
# pub/priv keys of each kind
|
||||
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
|
||||
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
|
||||
|
||||
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
|
||||
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
|
||||
|
||||
|
||||
|
||||
class ServerProtocolContractWithFactoryTests(AgentTestBase):
|
||||
"""
|
||||
The server protocol is stateful and so uses its factory to track state
|
||||
across requests. This test asserts that the protocol raises if its factory
|
||||
doesn't provide the necessary storage for that state.
|
||||
"""
|
||||
def test_factorySuppliesKeyStorageForServerProtocol(self):
|
||||
# need a message to send into the server
|
||||
msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
|
||||
del self.server.factory.__dict__['keys']
|
||||
self.assertRaises(MissingKeyStoreError,
|
||||
self.server.dataReceived, msg)
|
||||
|
||||
|
||||
|
||||
class UnimplementedVersionOneServerTests(AgentTestBase):
|
||||
"""
|
||||
Tests for methods with no-op implementations on the server. We need these
|
||||
for clients, such as openssh, that try v1 methods before going to v2.
|
||||
|
||||
Because the client doesn't expose these operations with nice method names,
|
||||
we invoke sendRequest directly with an op code.
|
||||
"""
|
||||
|
||||
def test_agentc_REQUEST_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA identities request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, '')
|
||||
self.pump.flush()
|
||||
def _cb(packet):
|
||||
self.assertEqual(
|
||||
agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0]))
|
||||
return d.addCallback(_cb)
|
||||
|
||||
|
||||
def test_agentc_REMOVE_RSA_IDENTITY(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove identity request
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, '')
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, '')
|
||||
|
||||
|
||||
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
|
||||
"""
|
||||
assert that we get the correct op code for an RSA remove all identities
|
||||
request.
|
||||
"""
|
||||
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, '')
|
||||
self.pump.flush()
|
||||
return d.addCallback(self.assertEqual, '')
|
||||
|
||||
|
||||
|
||||
if agent is not None:
|
||||
class CorruptServer(agent.SSHAgentServer):
|
||||
"""
|
||||
A misbehaving server that returns bogus response op codes so that we can
|
||||
verify that our callbacks that deal with these op codes handle such
|
||||
miscreants.
|
||||
"""
|
||||
def agentc_REQUEST_IDENTITIES(self, data):
|
||||
self.sendResponse(254, '')
|
||||
|
||||
|
||||
def agentc_SIGN_REQUEST(self, data):
|
||||
self.sendResponse(254, '')
|
||||
|
||||
|
||||
|
||||
class ClientWithBrokenServerTests(AgentTestBase):
|
||||
"""
|
||||
verify error handling code in the client using a misbehaving server
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.client, self.server, self.pump = iosim.connectedServerAndClient(
|
||||
CorruptServer, agent.SSHAgentClient)
|
||||
# the server's end of the protocol is stateful and we store it on the
|
||||
# factory, for which we only need a mock
|
||||
self.server.factory = StubFactory()
|
||||
|
||||
|
||||
def test_signDataCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.signData} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for data signing requests.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
def test_requestIdentitiesCallbackErrorHandling(self):
|
||||
"""
|
||||
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
|
||||
if we get a response from the server whose opcode doesn't match
|
||||
the protocol for identity requests.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
|
||||
class AgentKeyAdditionTests(AgentTestBase):
|
||||
"""
|
||||
Test adding different flavors of keys to an agent.
|
||||
"""
|
||||
|
||||
def test_addRSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that omitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual('', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addDSAIdentityNoComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that omitting the comment produces an
|
||||
empty string for the comment on the server.
|
||||
"""
|
||||
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual('', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addRSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.rsaPrivate.privateBlob(), comment='My special key')
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
|
||||
self.assertEqual(self.rsaPrivate, serverKey[0])
|
||||
self.assertEqual('My special key', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_addDSAIdentityWithComment(self):
|
||||
"""
|
||||
L{SSHAgentClient.addIdentity} adds the private key it is called
|
||||
with to the SSH agent server to which it is connected, associating
|
||||
it with the comment it is called with.
|
||||
|
||||
This test asserts that the server receives/stores the comment
|
||||
as sent by the client.
|
||||
"""
|
||||
d = self.client.addIdentity(
|
||||
self.dsaPrivate.privateBlob(), comment='My special key')
|
||||
self.pump.flush()
|
||||
def _check(ignored):
|
||||
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
|
||||
self.assertEqual(self.dsaPrivate, serverKey[0])
|
||||
self.assertEqual('My special key', serverKey[1])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
|
||||
class AgentClientFailureTests(AgentTestBase):
|
||||
def test_agentFailure(self):
|
||||
"""
|
||||
verify that the client raises ConchError on AGENT_FAILURE
|
||||
"""
|
||||
d = self.client.sendRequest(254, '')
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
|
||||
class AgentIdentityRequestsTests(AgentTestBase):
|
||||
"""
|
||||
Test operations against a server with identities already loaded.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate, 'a comment')
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate, 'another comment')
|
||||
|
||||
|
||||
def test_signDataRSA(self):
|
||||
"""
|
||||
Sign data with an RSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
def _check(sig):
|
||||
expected = self.rsaPrivate.sign("John Hancock")
|
||||
self.assertEqual(expected, sig)
|
||||
self.assertTrue(self.rsaPublic.verify(sig, "John Hancock"))
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_signDataDSA(self):
|
||||
"""
|
||||
Sign data with a DSA private key and then verify it with the public
|
||||
key.
|
||||
"""
|
||||
d = self.client.signData(self.dsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
def _check(sig):
|
||||
# Cannot do this b/c DSA uses random numbers when signing
|
||||
# expected = self.dsaPrivate.sign("John Hancock")
|
||||
# self.assertEqual(expected, sig)
|
||||
self.assertTrue(self.dsaPublic.verify(sig, "John Hancock"))
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_signDataRSAErrbackOnUnknownBlob(self):
|
||||
"""
|
||||
Assert that we get an errback if we try to sign data using a key that
|
||||
wasn't added.
|
||||
"""
|
||||
del self.server.factory.keys[self.rsaPublic.blob()]
|
||||
d = self.client.signData(self.rsaPublic.blob(), "John Hancock")
|
||||
self.pump.flush()
|
||||
return self.assertFailure(d, ConchError)
|
||||
|
||||
|
||||
def test_requestIdentities(self):
|
||||
"""
|
||||
Assert that we get all of the keys/comments that we add when we issue a
|
||||
request for all identities.
|
||||
"""
|
||||
d = self.client.requestIdentities()
|
||||
self.pump.flush()
|
||||
def _check(keyt):
|
||||
expected = {}
|
||||
expected[self.dsaPublic.blob()] = 'a comment'
|
||||
expected[self.rsaPublic.blob()] = 'another comment'
|
||||
|
||||
received = {}
|
||||
for k in keyt:
|
||||
received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
|
||||
self.assertEqual(expected, received)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
|
||||
class AgentKeyRemovalTests(AgentTestBase):
|
||||
"""
|
||||
Test support for removing keys in a remote server.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
AgentTestBase.setUp(self)
|
||||
self.server.factory.keys[self.dsaPrivate.blob()] = (
|
||||
self.dsaPrivate, 'a comment')
|
||||
self.server.factory.keys[self.rsaPrivate.blob()] = (
|
||||
self.rsaPrivate, 'another comment')
|
||||
|
||||
|
||||
def test_removeRSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove an RSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.rsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
|
||||
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_removeDSAIdentity(self):
|
||||
"""
|
||||
Assert that we can remove a DSA identity.
|
||||
"""
|
||||
# only need public key for this
|
||||
d = self.client.removeIdentity(self.dsaPrivate.blob())
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(1, len(self.server.factory.keys))
|
||||
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def test_removeAllIdentities(self):
|
||||
"""
|
||||
Assert that we can remove all identities.
|
||||
"""
|
||||
d = self.client.removeAllIdentities()
|
||||
self.pump.flush()
|
||||
|
||||
def _check(ignored):
|
||||
self.assertEqual(0, len(self.server.factory.keys))
|
||||
return d.addCallback(_check)
|
@ -0,0 +1,992 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_cftp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE file for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.cftp}.
|
||||
"""
|
||||
|
||||
import locale
|
||||
import time, sys, os, operator, getpass, struct
|
||||
from StringIO import StringIO
|
||||
|
||||
from twisted.conch.test.test_ssh import Crypto, pyasn1
|
||||
|
||||
_reason = None
|
||||
if Crypto and pyasn1:
|
||||
try:
|
||||
from twisted.conch import unix
|
||||
from twisted.conch.scripts import cftp
|
||||
from twisted.conch.scripts.cftp import SSHSession
|
||||
from twisted.conch.test.test_filetransfer import FileTransferForTestAvatar
|
||||
except ImportError as e:
|
||||
unix = None
|
||||
_reason = str(e)
|
||||
del e
|
||||
else:
|
||||
unix = None
|
||||
|
||||
from twisted.python.fakepwd import UserDatabase
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.cred import portal
|
||||
from twisted.internet import reactor, protocol, interfaces, defer, error
|
||||
from twisted.internet.utils import getProcessOutputAndValue
|
||||
from twisted.python import log
|
||||
from twisted.conch import ls
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
from twisted.internet.task import Clock
|
||||
|
||||
from twisted.conch.test import test_ssh, test_conch
|
||||
from twisted.conch.test.test_filetransfer import SFTPTestBase
|
||||
from twisted.conch.test.test_filetransfer import FileTransferTestAvatar
|
||||
from twisted.conch.test.test_conch import FakeStdio
|
||||
|
||||
|
||||
|
||||
class SSHSessionTests(TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.cftp.SSHSession}.
|
||||
"""
|
||||
def test_eofReceived(self):
|
||||
"""
|
||||
L{twisted.conch.scripts.cftp.SSHSession.eofReceived} loses the write
|
||||
half of its stdio connection.
|
||||
"""
|
||||
stdio = FakeStdio()
|
||||
channel = SSHSession()
|
||||
channel.stdio = stdio
|
||||
channel.eofReceived()
|
||||
self.assertTrue(stdio.writeConnLost)
|
||||
|
||||
|
||||
|
||||
class ListingTests(TestCase):
|
||||
"""
|
||||
Tests for L{lsLine}, the function which generates an entry for a file or
|
||||
directory in an SFTP I{ls} command's output.
|
||||
"""
|
||||
if getattr(time, 'tzset', None) is None:
|
||||
skip = "Cannot test timestamp formatting code without time.tzset"
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Patch the L{ls} module's time function so the results of L{lsLine} are
|
||||
deterministic.
|
||||
"""
|
||||
self.now = 123456789
|
||||
def fakeTime():
|
||||
return self.now
|
||||
self.patch(ls, 'time', fakeTime)
|
||||
|
||||
# Make sure that the timezone ends up the same after these tests as
|
||||
# it was before.
|
||||
if 'TZ' in os.environ:
|
||||
self.addCleanup(operator.setitem, os.environ, 'TZ', os.environ['TZ'])
|
||||
self.addCleanup(time.tzset)
|
||||
else:
|
||||
def cleanup():
|
||||
# os.environ.pop is broken! Don't use it! Ever! Or die!
|
||||
try:
|
||||
del os.environ['TZ']
|
||||
except KeyError:
|
||||
pass
|
||||
time.tzset()
|
||||
self.addCleanup(cleanup)
|
||||
|
||||
|
||||
def _lsInTimezone(self, timezone, stat):
|
||||
"""
|
||||
Call L{ls.lsLine} after setting the timezone to C{timezone} and return
|
||||
the result.
|
||||
"""
|
||||
# Set the timezone to a well-known value so the timestamps are
|
||||
# predictable.
|
||||
os.environ['TZ'] = timezone
|
||||
time.tzset()
|
||||
return ls.lsLine('foo', stat)
|
||||
|
||||
|
||||
def test_oldFile(self):
|
||||
"""
|
||||
A file with an mtime six months (approximately) or more in the past has
|
||||
a listing including a low-resolution timestamp.
|
||||
"""
|
||||
# Go with 7 months. That's more than 6 months.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 7)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Apr 26 1973 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Apr 27 1973 foo')
|
||||
|
||||
|
||||
def test_oldSingleDigitDayOfMonth(self):
|
||||
"""
|
||||
A file with a high-resolution timestamp which falls on a day of the
|
||||
month which can be represented by one decimal digit is formatted with
|
||||
one padding 0 to preserve the columns which come after it.
|
||||
"""
|
||||
# A point about 7 months in the past, tweaked to fall on the first of a
|
||||
# month so we test the case we want to test.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 7) + (60 * 60 * 24 * 5)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 May 01 1973 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 May 02 1973 foo')
|
||||
|
||||
|
||||
def test_newFile(self):
|
||||
"""
|
||||
A file with an mtime fewer than six months (approximately) in the past
|
||||
has a listing including a high-resolution timestamp excluding the year.
|
||||
"""
|
||||
# A point about three months in the past.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Aug 28 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Aug 29 09:33 foo')
|
||||
|
||||
|
||||
def test_localeIndependent(self):
|
||||
"""
|
||||
The month name in the date is locale independent.
|
||||
"""
|
||||
# A point about three months in the past.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
# Fake that we're in a language where August is not Aug (e.g.: Spanish)
|
||||
currentLocale = locale.getlocale()
|
||||
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
|
||||
self.addCleanup(locale.setlocale, locale.LC_ALL, currentLocale)
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Aug 28 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Aug 29 09:33 foo')
|
||||
|
||||
# if alternate locale is not available, the previous test will be
|
||||
# skipped, please install this locale for it to run
|
||||
currentLocale = locale.getlocale()
|
||||
try:
|
||||
try:
|
||||
locale.setlocale(locale.LC_ALL, "es_AR.UTF8")
|
||||
except locale.Error:
|
||||
test_localeIndependent.skip = "The es_AR.UTF8 locale is not installed."
|
||||
finally:
|
||||
locale.setlocale(locale.LC_ALL, currentLocale)
|
||||
|
||||
|
||||
def test_newSingleDigitDayOfMonth(self):
|
||||
"""
|
||||
A file with a high-resolution timestamp which falls on a day of the
|
||||
month which can be represented by one decimal digit is formatted with
|
||||
one padding 0 to preserve the columns which come after it.
|
||||
"""
|
||||
# A point about three months in the past, tweaked to fall on the first
|
||||
# of a month so we test the case we want to test.
|
||||
then = self.now - (60 * 60 * 24 * 31 * 3) + (60 * 60 * 24 * 4)
|
||||
stat = os.stat_result((0, 0, 0, 0, 0, 0, 0, 0, then, 0))
|
||||
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('America/New_York', stat),
|
||||
'!--------- 0 0 0 0 Sep 01 17:33 foo')
|
||||
self.assertEqual(
|
||||
self._lsInTimezone('Pacific/Auckland', stat),
|
||||
'!--------- 0 0 0 0 Sep 02 09:33 foo')
|
||||
|
||||
|
||||
|
||||
class StdioClientTests(TestCase):
|
||||
"""
|
||||
Tests for L{cftp.StdioClient}.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Create a L{cftp.StdioClient} hooked up to dummy transport and a fake
|
||||
user database.
|
||||
"""
|
||||
class Connection:
|
||||
pass
|
||||
|
||||
conn = Connection()
|
||||
conn.transport = StringTransport()
|
||||
conn.transport.localClosed = False
|
||||
|
||||
self.client = cftp.StdioClient(conn)
|
||||
self.database = self.client._pwd = UserDatabase()
|
||||
|
||||
# Intentionally bypassing makeConnection - that triggers some code
|
||||
# which uses features not provided by our dumb Connection fake.
|
||||
self.client.transport = StringTransport()
|
||||
|
||||
|
||||
def test_exec(self):
|
||||
"""
|
||||
The I{exec} command runs its arguments locally in a child process
|
||||
using the user's shell.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
|
||||
sys.executable)
|
||||
|
||||
d = self.client._dispatchCommand("exec print 1 + 2")
|
||||
d.addCallback(self.assertEqual, "3\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_execWithoutShell(self):
|
||||
"""
|
||||
If the local user has no shell, the I{exec} command runs its arguments
|
||||
using I{/bin/sh}.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar', '')
|
||||
|
||||
d = self.client._dispatchCommand("exec echo hello")
|
||||
d.addCallback(self.assertEqual, "hello\n")
|
||||
return d
|
||||
|
||||
|
||||
def test_bang(self):
|
||||
"""
|
||||
The I{exec} command is run for lines which start with C{"!"}.
|
||||
"""
|
||||
self.database.addUser(
|
||||
getpass.getuser(), 'secret', os.getuid(), 1234, 'foo', 'bar',
|
||||
'/bin/sh')
|
||||
|
||||
d = self.client._dispatchCommand("!echo hello")
|
||||
d.addCallback(self.assertEqual, "hello\n")
|
||||
return d
|
||||
|
||||
|
||||
def setKnownConsoleSize(self, width, height):
|
||||
"""
|
||||
For the duration of this test, patch C{cftp}'s C{fcntl} module to return
|
||||
a fixed width and height.
|
||||
|
||||
@param width: the width in characters
|
||||
@type width: C{int}
|
||||
@param height: the height in characters
|
||||
@type height: C{int}
|
||||
"""
|
||||
import tty # local import to avoid win32 issues
|
||||
class FakeFcntl(object):
|
||||
def ioctl(self, fd, opt, mutate):
|
||||
if opt != tty.TIOCGWINSZ:
|
||||
self.fail("Only window-size queries supported.")
|
||||
return struct.pack("4H", height, width, 0, 0)
|
||||
self.patch(cftp, "fcntl", FakeFcntl())
|
||||
|
||||
|
||||
def test_progressReporting(self):
|
||||
"""
|
||||
L{StdioClient._printProgressBar} prints a progress description,
|
||||
including percent done, amount transferred, transfer rate, and time
|
||||
remaining, all based the given start time, the given L{FileWrapper}'s
|
||||
progress information and the reactor's current time.
|
||||
"""
|
||||
# Use a short, known console width because this simple test doesn't need
|
||||
# to test the console padding.
|
||||
self.setKnownConsoleSize(10, 34)
|
||||
clock = self.client.reactor = Clock()
|
||||
wrapped = StringIO("x")
|
||||
wrapped.name = "sample"
|
||||
wrapper = cftp.FileWrapper(wrapped)
|
||||
wrapper.size = 1024 * 10
|
||||
startTime = clock.seconds()
|
||||
clock.advance(2.0)
|
||||
wrapper.total += 4096
|
||||
self.client._printProgressBar(wrapper, startTime)
|
||||
self.assertEqual(self.client.transport.value(),
|
||||
"\rsample 40% 4.0kB 2.0kBps 00:03 ")
|
||||
|
||||
|
||||
def test_reportNoProgress(self):
|
||||
"""
|
||||
L{StdioClient._printProgressBar} prints a progress description that
|
||||
indicates 0 bytes transferred if no bytes have been transferred and no
|
||||
time has passed.
|
||||
"""
|
||||
self.setKnownConsoleSize(10, 34)
|
||||
clock = self.client.reactor = Clock()
|
||||
wrapped = StringIO("x")
|
||||
wrapped.name = "sample"
|
||||
wrapper = cftp.FileWrapper(wrapped)
|
||||
startTime = clock.seconds()
|
||||
self.client._printProgressBar(wrapper, startTime)
|
||||
self.assertEqual(self.client.transport.value(),
|
||||
"\rsample 0% 0.0B 0.0Bps 00:00 ")
|
||||
|
||||
|
||||
|
||||
class FileTransferTestRealm:
|
||||
def __init__(self, testDir):
|
||||
self.testDir = testDir
|
||||
|
||||
def requestAvatar(self, avatarID, mind, *interfaces):
|
||||
a = FileTransferTestAvatar(self.testDir)
|
||||
return interfaces[0], a, lambda: None
|
||||
|
||||
|
||||
class SFTPTestProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Protocol for testing cftp. Provides an interface between Python (where all
|
||||
the tests are) and the cftp client process (which does the work that is
|
||||
being tested).
|
||||
"""
|
||||
|
||||
def __init__(self, onOutReceived):
|
||||
"""
|
||||
@param onOutReceived: A L{Deferred} to be fired as soon as data is
|
||||
received from stdout.
|
||||
"""
|
||||
self.clearBuffer()
|
||||
self.onOutReceived = onOutReceived
|
||||
self.onProcessEnd = None
|
||||
self._expectingCommand = None
|
||||
self._processEnded = False
|
||||
|
||||
def clearBuffer(self):
|
||||
"""
|
||||
Clear any buffered data received from stdout. Should be private.
|
||||
"""
|
||||
self.buffer = ''
|
||||
self._linesReceived = []
|
||||
self._lineBuffer = ''
|
||||
|
||||
def outReceived(self, data):
|
||||
"""
|
||||
Called by Twisted when the cftp client prints data to stdout.
|
||||
"""
|
||||
log.msg('got %s' % data)
|
||||
lines = (self._lineBuffer + data).split('\n')
|
||||
self._lineBuffer = lines.pop(-1)
|
||||
self._linesReceived.extend(lines)
|
||||
# XXX - not strictly correct.
|
||||
# We really want onOutReceived to fire after the first 'cftp>' prompt
|
||||
# has been received. (See use in OurServerCmdLineClientTests.setUp)
|
||||
if self.onOutReceived is not None:
|
||||
d, self.onOutReceived = self.onOutReceived, None
|
||||
d.callback(data)
|
||||
self.buffer += data
|
||||
self._checkForCommand()
|
||||
|
||||
def _checkForCommand(self):
|
||||
prompt = 'cftp> '
|
||||
if self._expectingCommand and self._lineBuffer == prompt:
|
||||
buf = '\n'.join(self._linesReceived)
|
||||
if buf.startswith(prompt):
|
||||
buf = buf[len(prompt):]
|
||||
self.clearBuffer()
|
||||
d, self._expectingCommand = self._expectingCommand, None
|
||||
d.callback(buf)
|
||||
|
||||
def errReceived(self, data):
|
||||
"""
|
||||
Called by Twisted when the cftp client prints data to stderr.
|
||||
"""
|
||||
log.msg('err: %s' % data)
|
||||
|
||||
def getBuffer(self):
|
||||
"""
|
||||
Return the contents of the buffer of data received from stdout.
|
||||
"""
|
||||
return self.buffer
|
||||
|
||||
def runCommand(self, command):
|
||||
"""
|
||||
Issue the given command via the cftp client. Return a C{Deferred} that
|
||||
fires when the server returns a result. Note that the C{Deferred} will
|
||||
callback even if the server returns some kind of error.
|
||||
|
||||
@param command: A string containing an sftp command.
|
||||
|
||||
@return: A C{Deferred} that fires when the sftp server returns a
|
||||
result. The payload is the server's response string.
|
||||
"""
|
||||
self._expectingCommand = defer.Deferred()
|
||||
self.clearBuffer()
|
||||
self.transport.write(command + '\n')
|
||||
return self._expectingCommand
|
||||
|
||||
def runScript(self, commands):
|
||||
"""
|
||||
Run each command in sequence and return a Deferred that fires when all
|
||||
commands are completed.
|
||||
|
||||
@param commands: A list of strings containing sftp commands.
|
||||
|
||||
@return: A C{Deferred} that fires when all commands are completed. The
|
||||
payload is a list of response strings from the server, in the same
|
||||
order as the commands.
|
||||
"""
|
||||
sem = defer.DeferredSemaphore(1)
|
||||
dl = [sem.run(self.runCommand, command) for command in commands]
|
||||
return defer.gatherResults(dl)
|
||||
|
||||
def killProcess(self):
|
||||
"""
|
||||
Kill the process if it is still running.
|
||||
|
||||
If the process is still running, sends a KILL signal to the transport
|
||||
and returns a C{Deferred} which fires when L{processEnded} is called.
|
||||
|
||||
@return: a C{Deferred}.
|
||||
"""
|
||||
if self._processEnded:
|
||||
return defer.succeed(None)
|
||||
self.onProcessEnd = defer.Deferred()
|
||||
self.transport.signalProcess('KILL')
|
||||
return self.onProcessEnd
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Called by Twisted when the cftp client process ends.
|
||||
"""
|
||||
self._processEnded = True
|
||||
if self.onProcessEnd:
|
||||
d, self.onProcessEnd = self.onProcessEnd, None
|
||||
d.callback(None)
|
||||
|
||||
|
||||
class CFTPClientTestBase(SFTPTestBase):
|
||||
def setUp(self):
|
||||
f = open('dsa_test.pub','w')
|
||||
f.write(test_ssh.publicDSA_openssh)
|
||||
f.close()
|
||||
f = open('dsa_test','w')
|
||||
f.write(test_ssh.privateDSA_openssh)
|
||||
f.close()
|
||||
os.chmod('dsa_test', 33152)
|
||||
f = open('kh_test','w')
|
||||
f.write('127.0.0.1 ' + test_ssh.publicRSA_openssh)
|
||||
f.close()
|
||||
return SFTPTestBase.setUp(self)
|
||||
|
||||
def startServer(self):
|
||||
realm = FileTransferTestRealm(self.testDir)
|
||||
p = portal.Portal(realm)
|
||||
p.registerChecker(test_ssh.conchTestPublicKeyChecker())
|
||||
fac = test_ssh.ConchTestServerFactory()
|
||||
fac.portal = p
|
||||
self.server = reactor.listenTCP(0, fac, interface="127.0.0.1")
|
||||
|
||||
def stopServer(self):
|
||||
if not hasattr(self.server.factory, 'proto'):
|
||||
return self._cbStopServer(None)
|
||||
self.server.factory.proto.expectedLoseConnection = 1
|
||||
d = defer.maybeDeferred(
|
||||
self.server.factory.proto.transport.loseConnection)
|
||||
d.addCallback(self._cbStopServer)
|
||||
return d
|
||||
|
||||
def _cbStopServer(self, ignored):
|
||||
return defer.maybeDeferred(self.server.stopListening)
|
||||
|
||||
def tearDown(self):
|
||||
for f in ['dsa_test.pub', 'dsa_test', 'kh_test']:
|
||||
try:
|
||||
os.remove(f)
|
||||
except:
|
||||
pass
|
||||
return SFTPTestBase.tearDown(self)
|
||||
|
||||
|
||||
|
||||
class OurServerCmdLineClientTests(CFTPClientTestBase):
|
||||
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
|
||||
self.startServer()
|
||||
cmds = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-i dsa_test '
|
||||
'-a '
|
||||
'-v '
|
||||
'127.0.0.1')
|
||||
port = self.server.getHost().port
|
||||
cmds = test_conch._makeArgs((cmds % port).split(), mod='cftp')
|
||||
log.msg('running %s %s' % (sys.executable, cmds))
|
||||
d = defer.Deferred()
|
||||
self.processProtocol = SFTPTestProcess(d)
|
||||
d.addCallback(lambda _: self.processProtocol.clearBuffer())
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
reactor.spawnProcess(self.processProtocol, sys.executable, cmds,
|
||||
env=env)
|
||||
return d
|
||||
|
||||
def tearDown(self):
|
||||
d = self.stopServer()
|
||||
d.addCallback(lambda _: self.processProtocol.killProcess())
|
||||
return d
|
||||
|
||||
def _killProcess(self, ignored):
|
||||
try:
|
||||
self.processProtocol.transport.signalProcess('KILL')
|
||||
except error.ProcessExitedAlready:
|
||||
pass
|
||||
|
||||
def runCommand(self, command):
|
||||
"""
|
||||
Run the given command with the cftp client. Return a C{Deferred} that
|
||||
fires when the command is complete. Payload is the server's output for
|
||||
that command.
|
||||
"""
|
||||
return self.processProtocol.runCommand(command)
|
||||
|
||||
def runScript(self, *commands):
|
||||
"""
|
||||
Run the given commands with the cftp client. Returns a C{Deferred}
|
||||
that fires when the commands are all complete. The C{Deferred}'s
|
||||
payload is a list of output for each command.
|
||||
"""
|
||||
return self.processProtocol.runScript(commands)
|
||||
|
||||
def testCdPwd(self):
|
||||
"""
|
||||
Test that 'pwd' reports the current remote directory, that 'lpwd'
|
||||
reports the current local directory, and that changing to a
|
||||
subdirectory then changing to its parent leaves you in the original
|
||||
remote directory.
|
||||
"""
|
||||
# XXX - not actually a unit test, see docstring.
|
||||
homeDir = os.path.join(os.getcwd(), self.testDir)
|
||||
d = self.runScript('pwd', 'lpwd', 'cd testDirectory', 'cd ..', 'pwd')
|
||||
d.addCallback(lambda xs: xs[:3] + xs[4:])
|
||||
d.addCallback(self.assertEqual,
|
||||
[homeDir, os.getcwd(), '', homeDir])
|
||||
return d
|
||||
|
||||
def testChAttrs(self):
|
||||
"""
|
||||
Check that 'ls -l' output includes the access permissions and that
|
||||
this output changes appropriately with 'chmod'.
|
||||
"""
|
||||
def _check(results):
|
||||
self.flushLoggedErrors()
|
||||
self.assertTrue(results[0].startswith('-rw-r--r--'))
|
||||
self.assertEqual(results[1], '')
|
||||
self.assertTrue(results[2].startswith('----------'), results[2])
|
||||
self.assertEqual(results[3], '')
|
||||
|
||||
d = self.runScript('ls -l testfile1', 'chmod 0 testfile1',
|
||||
'ls -l testfile1', 'chmod 644 testfile1')
|
||||
return d.addCallback(_check)
|
||||
# XXX test chgrp/own
|
||||
|
||||
|
||||
def testList(self):
|
||||
"""
|
||||
Check 'ls' works as expected. Checks for wildcards, hidden files,
|
||||
listing directories and listing empty directories.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
self.assertEqual(results[1], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
self.assertEqual(results[2], ['testRemoveFile', 'testRenameFile'])
|
||||
self.assertEqual(results[3], ['.testHiddenFile', 'testRemoveFile',
|
||||
'testRenameFile'])
|
||||
self.assertEqual(results[4], [''])
|
||||
d = self.runScript('ls', 'ls ../' + os.path.basename(self.testDir),
|
||||
'ls *File', 'ls -a *File', 'ls -l testDirectory')
|
||||
d.addCallback(lambda xs: [x.split('\n') for x in xs])
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def testHelp(self):
|
||||
"""
|
||||
Check that running the '?' command returns help.
|
||||
"""
|
||||
d = self.runCommand('?')
|
||||
d.addCallback(self.assertEqual,
|
||||
cftp.StdioClient(None).cmd_HELP('').strip())
|
||||
return d
|
||||
|
||||
def assertFilesEqual(self, name1, name2, msg=None):
|
||||
"""
|
||||
Assert that the files at C{name1} and C{name2} contain exactly the
|
||||
same data.
|
||||
"""
|
||||
f1 = file(name1).read()
|
||||
f2 = file(name2).read()
|
||||
self.assertEqual(f1, f2, msg)
|
||||
|
||||
|
||||
def testGet(self):
|
||||
"""
|
||||
Test that 'get' saves the remote file to the correct local location,
|
||||
that the output of 'get' is correct and that 'rm' actually removes
|
||||
the file.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
expectedOutput = ("Transferred %s/%s/testfile1 to %s/test file2"
|
||||
% (os.getcwd(), self.testDir, self.testDir))
|
||||
def _checkGet(result):
|
||||
self.assertTrue(result.endswith(expectedOutput))
|
||||
self.assertFilesEqual(self.testDir + '/testfile1',
|
||||
self.testDir + '/test file2',
|
||||
"get failed")
|
||||
return self.runCommand('rm "test file2"')
|
||||
|
||||
d = self.runCommand('get testfile1 "%s/test file2"' % (self.testDir,))
|
||||
d.addCallback(_checkGet)
|
||||
d.addCallback(lambda _: self.assertFalse(
|
||||
os.path.exists(self.testDir + '/test file2')))
|
||||
return d
|
||||
|
||||
|
||||
def testWildcardGet(self):
|
||||
"""
|
||||
Test that 'get' works correctly when given wildcard parameters.
|
||||
"""
|
||||
def _check(ignored):
|
||||
self.assertFilesEqual(self.testDir + '/testRemoveFile',
|
||||
'testRemoveFile',
|
||||
'testRemoveFile get failed')
|
||||
self.assertFilesEqual(self.testDir + '/testRenameFile',
|
||||
'testRenameFile',
|
||||
'testRenameFile get failed')
|
||||
|
||||
d = self.runCommand('get testR*')
|
||||
return d.addCallback(_check)
|
||||
|
||||
|
||||
def testPut(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly and that they can be
|
||||
successfully removed. Also check the output of the put command.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
expectedOutput = ('Transferred %s/testfile1 to %s/%s/test"file2'
|
||||
% (self.testDir, os.getcwd(), self.testDir))
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/testfile1',
|
||||
self.testDir + '/test"file2')
|
||||
self.assertTrue(result.endswith(expectedOutput))
|
||||
return self.runCommand('rm "test\\"file2"')
|
||||
|
||||
d = self.runCommand('put %s/testfile1 "test\\"file2"'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
d.addCallback(lambda _: self.assertFalse(
|
||||
os.path.exists(self.testDir + '/test"file2')))
|
||||
return d
|
||||
|
||||
|
||||
def test_putOverLongerFile(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly when overwriting a longer
|
||||
file.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
f = file(os.path.join(self.testDir, 'shorterFile'), 'w')
|
||||
f.write("a")
|
||||
f.close()
|
||||
f = file(os.path.join(self.testDir, 'longerFile'), 'w')
|
||||
f.write("bb")
|
||||
f.close()
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/shorterFile',
|
||||
self.testDir + '/longerFile')
|
||||
|
||||
d = self.runCommand('put %s/shorterFile longerFile'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
return d
|
||||
|
||||
|
||||
def test_putMultipleOverLongerFile(self):
|
||||
"""
|
||||
Check that 'put' uploads files correctly when overwriting a longer
|
||||
file and you use a wildcard to specify the files to upload.
|
||||
"""
|
||||
# XXX - not actually a unit test
|
||||
os.mkdir(os.path.join(self.testDir, 'dir'))
|
||||
f = file(os.path.join(self.testDir, 'dir', 'file'), 'w')
|
||||
f.write("a")
|
||||
f.close()
|
||||
f = file(os.path.join(self.testDir, 'file'), 'w')
|
||||
f.write("bb")
|
||||
f.close()
|
||||
def _checkPut(result):
|
||||
self.assertFilesEqual(self.testDir + '/dir/file',
|
||||
self.testDir + '/file')
|
||||
|
||||
d = self.runCommand('put %s/dir/*'
|
||||
% (self.testDir,))
|
||||
d.addCallback(_checkPut)
|
||||
return d
|
||||
|
||||
|
||||
def testWildcardPut(self):
|
||||
"""
|
||||
What happens if you issue a 'put' command and include a wildcard (i.e.
|
||||
'*') in parameter? Check that all files matching the wildcard are
|
||||
uploaded to the correct directory.
|
||||
"""
|
||||
def check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[2], '')
|
||||
self.assertFilesEqual(self.testDir + '/testRemoveFile',
|
||||
self.testDir + '/../testRemoveFile',
|
||||
'testRemoveFile get failed')
|
||||
self.assertFilesEqual(self.testDir + '/testRenameFile',
|
||||
self.testDir + '/../testRenameFile',
|
||||
'testRenameFile get failed')
|
||||
|
||||
d = self.runScript('cd ..',
|
||||
'put %s/testR*' % (self.testDir,),
|
||||
'cd %s' % os.path.basename(self.testDir))
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
|
||||
def testLink(self):
|
||||
"""
|
||||
Test that 'ln' creates a file which appears as a link in the output of
|
||||
'ls'. Check that removing the new file succeeds without output.
|
||||
"""
|
||||
def _check(results):
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertTrue(results[1].startswith('l'), 'link failed')
|
||||
return self.runCommand('rm testLink')
|
||||
|
||||
d = self.runScript('ln testLink testfile1', 'ls -l testLink')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def testRemoteDirectory(self):
|
||||
"""
|
||||
Test that we can create and remove directories with the cftp client.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertTrue(results[1].startswith('d'))
|
||||
return self.runCommand('rmdir testMakeDirectory')
|
||||
|
||||
d = self.runScript('mkdir testMakeDirectory',
|
||||
'ls -l testMakeDirector?')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def test_existingRemoteDirectory(self):
|
||||
"""
|
||||
Test that a C{mkdir} on an existing directory fails with the
|
||||
appropriate error, and doesn't log an useless error server side.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[1],
|
||||
'remote error 11: mkdir failed')
|
||||
|
||||
d = self.runScript('mkdir testMakeDirectory',
|
||||
'mkdir testMakeDirectory')
|
||||
d.addCallback(_check)
|
||||
return d
|
||||
|
||||
|
||||
def testLocalDirectory(self):
|
||||
"""
|
||||
Test that we can create a directory locally and remove it with the
|
||||
cftp client. This test works because the 'remote' server is running
|
||||
out of a local directory.
|
||||
"""
|
||||
d = self.runCommand('lmkdir %s/testLocalDirectory' % (self.testDir,))
|
||||
d.addCallback(self.assertEqual, '')
|
||||
d.addCallback(lambda _: self.runCommand('rmdir testLocalDirectory'))
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
def testRename(self):
|
||||
"""
|
||||
Test that we can rename a file.
|
||||
"""
|
||||
def _check(results):
|
||||
self.assertEqual(results[0], '')
|
||||
self.assertEqual(results[1], 'testfile2')
|
||||
return self.runCommand('rename testfile2 testfile1')
|
||||
|
||||
d = self.runScript('rename testfile1 testfile2', 'ls testfile?')
|
||||
d.addCallback(_check)
|
||||
d.addCallback(self.assertEqual, '')
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class OurServerBatchFileTests(CFTPClientTestBase):
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
self.startServer()
|
||||
|
||||
def tearDown(self):
|
||||
CFTPClientTestBase.tearDown(self)
|
||||
return self.stopServer()
|
||||
|
||||
def _getBatchOutput(self, f):
|
||||
fn = self.mktemp()
|
||||
open(fn, 'w').write(f)
|
||||
port = self.server.getHost().port
|
||||
cmds = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-i dsa_test '
|
||||
'-a '
|
||||
'-v -b %s 127.0.0.1') % (port, fn)
|
||||
cmds = test_conch._makeArgs(cmds.split(), mod='cftp')[1:]
|
||||
log.msg('running %s %s' % (sys.executable, cmds))
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
|
||||
self.server.factory.expectedLoseConnection = 1
|
||||
|
||||
d = getProcessOutputAndValue(sys.executable, cmds, env=env)
|
||||
|
||||
def _cleanup(res):
|
||||
os.remove(fn)
|
||||
return res
|
||||
|
||||
d.addCallback(lambda res: res[0])
|
||||
d.addBoth(_cleanup)
|
||||
|
||||
return d
|
||||
|
||||
def testBatchFile(self):
|
||||
"""Test whether batch file function of cftp ('cftp -b batchfile').
|
||||
This works by treating the file as a list of commands to be run.
|
||||
"""
|
||||
cmds = """pwd
|
||||
ls
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
res = res.split('\n')
|
||||
log.msg('RES %s' % str(res))
|
||||
self.assertIn(self.testDir, res[1])
|
||||
self.assertEqual(res[3:-2], ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1'])
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
def testError(self):
|
||||
"""Test that an error in the batch file stops running the batch.
|
||||
"""
|
||||
cmds = """chown 0 missingFile
|
||||
pwd
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
self.assertNotIn(self.testDir, res)
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
def testIgnoredError(self):
|
||||
"""Test that a minus sign '-' at the front of a line ignores
|
||||
any errors.
|
||||
"""
|
||||
cmds = """-chown 0 missingFile
|
||||
pwd
|
||||
exit
|
||||
"""
|
||||
def _cbCheckResult(res):
|
||||
self.assertIn(self.testDir, res)
|
||||
|
||||
d = self._getBatchOutput(cmds)
|
||||
d.addCallback(_cbCheckResult)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class OurServerSftpClientTests(CFTPClientTestBase):
|
||||
"""
|
||||
Test the sftp server against sftp command line client.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
CFTPClientTestBase.setUp(self)
|
||||
return self.startServer()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
return self.stopServer()
|
||||
|
||||
|
||||
def test_extendedAttributes(self):
|
||||
"""
|
||||
Test the return of extended attributes by the server: the sftp client
|
||||
should ignore them, but still be able to parse the response correctly.
|
||||
|
||||
This test is mainly here to check that
|
||||
L{filetransfer.FILEXFER_ATTR_EXTENDED} has the correct value.
|
||||
"""
|
||||
fn = self.mktemp()
|
||||
open(fn, 'w').write("ls .\nexit")
|
||||
port = self.server.getHost().port
|
||||
|
||||
oldGetAttr = FileTransferForTestAvatar._getAttrs
|
||||
def _getAttrs(self, s):
|
||||
attrs = oldGetAttr(self, s)
|
||||
attrs["ext_foo"] = "bar"
|
||||
return attrs
|
||||
|
||||
self.patch(FileTransferForTestAvatar, "_getAttrs", _getAttrs)
|
||||
|
||||
self.server.factory.expectedLoseConnection = True
|
||||
cmds = ('-o', 'IdentityFile=dsa_test',
|
||||
'-o', 'UserKnownHostsFile=kh_test',
|
||||
'-o', 'HostKeyAlgorithms=ssh-rsa',
|
||||
'-o', 'Port=%i' % (port,), '-b', fn, 'testuser@127.0.0.1')
|
||||
d = getProcessOutputAndValue("sftp", cmds)
|
||||
def check(result):
|
||||
self.assertEqual(result[2], 0)
|
||||
for i in ['testDirectory', 'testRemoveFile',
|
||||
'testRenameFile', 'testfile1']:
|
||||
self.assertIn(i, result[0])
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
|
||||
if unix is None or Crypto is None or pyasn1 is None or interfaces.IReactorProcess(reactor, None) is None:
|
||||
if _reason is None:
|
||||
_reason = "don't run w/o spawnProcess or PyCrypto or pyasn1"
|
||||
OurServerCmdLineClientTests.skip = _reason
|
||||
OurServerBatchFileTests.skip = _reason
|
||||
OurServerSftpClientTests.skip = _reason
|
||||
StdioClientTests.skip = _reason
|
||||
SSHSessionTests.skip = _reason
|
||||
else:
|
||||
from twisted.python.procutils import which
|
||||
if not which('sftp'):
|
||||
OurServerSftpClientTests.skip = "no sftp command-line client available"
|
@ -0,0 +1,279 @@
|
||||
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
|
||||
# See LICENSE for details
|
||||
|
||||
"""
|
||||
Test ssh/channel.py.
|
||||
"""
|
||||
from twisted.conch.ssh import channel
|
||||
from twisted.trial import unittest
|
||||
|
||||
|
||||
class MockTransport(object):
|
||||
"""
|
||||
A mock Transport. All we use is the getPeer() and getHost() methods.
|
||||
Channels implement the ITransport interface, and their getPeer() and
|
||||
getHost() methods return ('SSH', <transport's getPeer/Host value>) so
|
||||
we need to implement these methods so they have something to draw
|
||||
from.
|
||||
"""
|
||||
def getPeer(self):
|
||||
return ('MockPeer',)
|
||||
|
||||
def getHost(self):
|
||||
return ('MockHost',)
|
||||
|
||||
|
||||
class MockConnection(object):
|
||||
"""
|
||||
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
|
||||
that channels send, and when they try to close the connection.
|
||||
|
||||
@ivar data: a C{dict} mapping channel id #s to lists of data sent by that
|
||||
channel.
|
||||
@ivar extData: a C{dict} mapping channel id #s to lists of 2-tuples
|
||||
(extended data type, data) sent by that channel.
|
||||
@ivar closes: a C{dict} mapping channel id #s to True if that channel sent
|
||||
a close message.
|
||||
"""
|
||||
transport = MockTransport()
|
||||
|
||||
def __init__(self):
|
||||
self.data = {}
|
||||
self.extData = {}
|
||||
self.closes = {}
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Return our logging prefix.
|
||||
"""
|
||||
return "MockConnection"
|
||||
|
||||
def sendData(self, channel, data):
|
||||
"""
|
||||
Record the sent data.
|
||||
"""
|
||||
self.data.setdefault(channel, []).append(data)
|
||||
|
||||
def sendExtendedData(self, channel, type, data):
|
||||
"""
|
||||
Record the sent extended data.
|
||||
"""
|
||||
self.extData.setdefault(channel, []).append((type, data))
|
||||
|
||||
def sendClose(self, channel):
|
||||
"""
|
||||
Record that the channel sent a close message.
|
||||
"""
|
||||
self.closes[channel] = True
|
||||
|
||||
|
||||
class ChannelTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""
|
||||
Initialize the channel. remoteMaxPacket is 10 so that data is able
|
||||
to be sent (the default of 0 means no data is sent because no packets
|
||||
are made).
|
||||
"""
|
||||
self.conn = MockConnection()
|
||||
self.channel = channel.SSHChannel(conn=self.conn,
|
||||
remoteMaxPacket=10)
|
||||
self.channel.name = 'channel'
|
||||
|
||||
def test_init(self):
|
||||
"""
|
||||
Test that SSHChannel initializes correctly. localWindowSize defaults
|
||||
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
|
||||
defaults (what OpenSSH uses for those variables).
|
||||
|
||||
The values in the second set of assertions are meaningless; they serve
|
||||
only to verify that the instance variables are assigned in the correct
|
||||
order.
|
||||
"""
|
||||
c = channel.SSHChannel(conn=self.conn)
|
||||
self.assertEqual(c.localWindowSize, 131072)
|
||||
self.assertEqual(c.localWindowLeft, 131072)
|
||||
self.assertEqual(c.localMaxPacket, 32768)
|
||||
self.assertEqual(c.remoteWindowLeft, 0)
|
||||
self.assertEqual(c.remoteMaxPacket, 0)
|
||||
self.assertEqual(c.conn, self.conn)
|
||||
self.assertEqual(c.data, None)
|
||||
self.assertEqual(c.avatar, None)
|
||||
|
||||
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
|
||||
self.assertEqual(c2.localWindowSize, 1)
|
||||
self.assertEqual(c2.localWindowLeft, 1)
|
||||
self.assertEqual(c2.localMaxPacket, 2)
|
||||
self.assertEqual(c2.remoteWindowLeft, 3)
|
||||
self.assertEqual(c2.remoteMaxPacket, 4)
|
||||
self.assertEqual(c2.conn, 5)
|
||||
self.assertEqual(c2.data, 6)
|
||||
self.assertEqual(c2.avatar, 7)
|
||||
|
||||
def test_str(self):
|
||||
"""
|
||||
Test that str(SSHChannel) works gives the channel name and local and
|
||||
remote windows at a glance..
|
||||
"""
|
||||
self.assertEqual(str(self.channel), '<SSHChannel channel (lw 131072 '
|
||||
'rw 0)>')
|
||||
|
||||
def test_logPrefix(self):
|
||||
"""
|
||||
Test that SSHChannel.logPrefix gives the name of the channel, the
|
||||
local channel ID and the underlying connection.
|
||||
"""
|
||||
self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
|
||||
'(unknown) on MockConnection')
|
||||
|
||||
def test_addWindowBytes(self):
|
||||
"""
|
||||
Test that addWindowBytes adds bytes to the window and resumes writing
|
||||
if it was paused.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStartWriting():
|
||||
cb[0] = True
|
||||
self.channel.startWriting = stubStartWriting
|
||||
self.channel.write('test')
|
||||
self.channel.writeExtended(1, 'test')
|
||||
self.channel.addWindowBytes(50)
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
|
||||
self.assertTrue(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(self.channel.buf, '')
|
||||
self.assertEqual(self.conn.data[self.channel], ['test'])
|
||||
self.assertEqual(self.channel.extBuf, [])
|
||||
self.assertEqual(self.conn.extData[self.channel], [(1, 'test')])
|
||||
|
||||
cb[0] = False
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
self.channel.write('a'*80)
|
||||
self.channel.loseConnection()
|
||||
self.channel.addWindowBytes(20)
|
||||
self.assertFalse(cb[0])
|
||||
|
||||
def test_requestReceived(self):
|
||||
"""
|
||||
Test that requestReceived handles requests by dispatching them to
|
||||
request_* methods.
|
||||
"""
|
||||
self.channel.request_test_method = lambda data: data == ''
|
||||
self.assertTrue(self.channel.requestReceived('test-method', ''))
|
||||
self.assertFalse(self.channel.requestReceived('test-method', 'a'))
|
||||
self.assertFalse(self.channel.requestReceived('bad-method', ''))
|
||||
|
||||
def test_closeReceieved(self):
|
||||
"""
|
||||
Test that the default closeReceieved closes the connection.
|
||||
"""
|
||||
self.assertFalse(self.channel.closing)
|
||||
self.channel.closeReceived()
|
||||
self.assertTrue(self.channel.closing)
|
||||
|
||||
def test_write(self):
|
||||
"""
|
||||
Test that write handles data correctly. Send data up to the size
|
||||
of the remote window, splitting the data into packets of length
|
||||
remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStopWriting():
|
||||
cb[0] = True
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting
|
||||
self.channel.write('d')
|
||||
self.channel.write('a')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.write('ta')
|
||||
data = self.conn.data[self.channel]
|
||||
self.assertEqual(data, ['da', 'ta'])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.write('12345678901')
|
||||
self.assertEqual(data, ['da', 'ta', '1234567890', '1'])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.write('123456')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(data, ['da', 'ta', '1234567890', '1', '12345'])
|
||||
self.assertEqual(self.channel.buf, '6')
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeExtended(self):
|
||||
"""
|
||||
Test that writeExtended handles data correctly. Send extended data
|
||||
up to the size of the window, splitting the extended data into packets
|
||||
of length remoteMaxPacket.
|
||||
"""
|
||||
cb = [False]
|
||||
def stubStopWriting():
|
||||
cb[0] = True
|
||||
# no window to start with
|
||||
self.channel.stopWriting = stubStopWriting
|
||||
self.channel.writeExtended(1, 'd')
|
||||
self.channel.writeExtended(1, 'a')
|
||||
self.channel.writeExtended(2, 't')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
# regular write
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeExtended(2, 'a')
|
||||
data = self.conn.extData[self.channel]
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a')])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 16)
|
||||
# larger than max packet
|
||||
self.channel.writeExtended(3, '12345678901')
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
|
||||
(3, '1234567890'), (3, '1')])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 5)
|
||||
# running out of window
|
||||
cb[0] = False
|
||||
self.channel.writeExtended(4, '123456')
|
||||
self.assertFalse(self.channel.areWriting)
|
||||
self.assertTrue(cb[0])
|
||||
self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
|
||||
(3, '1234567890'), (3, '1'), (4, '12345')])
|
||||
self.assertEqual(self.channel.extBuf, [[4, '6']])
|
||||
self.assertEqual(self.channel.remoteWindowLeft, 0)
|
||||
|
||||
def test_writeSequence(self):
|
||||
"""
|
||||
Test that writeSequence is equivalent to write(''.join(sequece)).
|
||||
"""
|
||||
self.channel.addWindowBytes(20)
|
||||
self.channel.writeSequence(map(str, range(10)))
|
||||
self.assertEqual(self.conn.data[self.channel], ['0123456789'])
|
||||
|
||||
def test_loseConnection(self):
|
||||
"""
|
||||
Tesyt that loseConnection() doesn't close the channel until all
|
||||
the data is sent.
|
||||
"""
|
||||
self.channel.write('data')
|
||||
self.channel.writeExtended(1, 'datadata')
|
||||
self.channel.loseConnection()
|
||||
self.assertEqual(self.conn.closes.get(self.channel), None)
|
||||
self.channel.addWindowBytes(4) # send regular data
|
||||
self.assertEqual(self.conn.closes.get(self.channel), None)
|
||||
self.channel.addWindowBytes(8) # send extended data
|
||||
self.assertTrue(self.conn.closes.get(self.channel))
|
||||
|
||||
def test_getPeer(self):
|
||||
"""
|
||||
Test that getPeer() returns ('SSH', <connection transport peer>).
|
||||
"""
|
||||
self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))
|
||||
|
||||
def test_getHost(self):
|
||||
"""
|
||||
Test that getHost() returns ('SSH', <connection transport host>).
|
||||
"""
|
||||
self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))
|
@ -0,0 +1,892 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.checkers}.
|
||||
"""
|
||||
|
||||
try:
|
||||
import crypt
|
||||
except ImportError:
|
||||
cryptSkip = 'cannot run without crypt module'
|
||||
else:
|
||||
cryptSkip = None
|
||||
|
||||
import os, base64
|
||||
from collections import namedtuple
|
||||
from io import StringIO
|
||||
|
||||
from zope.interface.verify import verifyObject
|
||||
|
||||
from twisted.python import util
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.python.reflect import requireModule
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
|
||||
from twisted.cred.credentials import UsernamePassword, IUsernamePassword, \
|
||||
SSHPrivateKey, ISSHPrivateKey
|
||||
from twisted.cred.error import UnhandledCredentials, UnauthorizedLogin
|
||||
from twisted.python.fakepwd import UserDatabase, ShadowDatabase
|
||||
from twisted.test.test_process import MockOS
|
||||
|
||||
if requireModule('Crypto.Cipher.DES3') and requireModule('pyasn1'):
|
||||
dependencySkip = None
|
||||
from twisted.conch.ssh import keys
|
||||
from twisted.conch import checkers
|
||||
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
|
||||
from twisted.conch.test import keydata
|
||||
else:
|
||||
dependencySkip = "can't run without Crypto and PyASN1"
|
||||
|
||||
if getattr(os, 'geteuid', None) is None:
|
||||
euidSkip = "Cannot run without effective UIDs (questionable)"
|
||||
else:
|
||||
euidSkip = None
|
||||
|
||||
|
||||
class HelperTests(TestCase):
|
||||
"""
|
||||
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
|
||||
L{_shadowGetByName}.
|
||||
"""
|
||||
skip = cryptSkip or dependencySkip
|
||||
|
||||
def setUp(self):
|
||||
self.mockos = MockOS()
|
||||
|
||||
|
||||
def test_verifyCryptedPassword(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns C{True} if the plaintext password
|
||||
passed to it matches the encrypted password passed to it.
|
||||
"""
|
||||
password = 'secret string'
|
||||
salt = 'salty'
|
||||
crypted = crypt.crypt(password, salt)
|
||||
self.assertTrue(
|
||||
checkers.verifyCryptedPassword(crypted, password),
|
||||
'%r supposed to be valid encrypted password for %r' % (
|
||||
crypted, password))
|
||||
|
||||
|
||||
def test_verifyCryptedPasswordMD5(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns True if the provided cleartext password
|
||||
matches the provided MD5 password hash.
|
||||
"""
|
||||
password = 'password'
|
||||
salt = '$1$salt'
|
||||
crypted = crypt.crypt(password, salt)
|
||||
self.assertTrue(
|
||||
checkers.verifyCryptedPassword(crypted, password),
|
||||
'%r supposed to be valid encrypted password for %s' % (
|
||||
crypted, password))
|
||||
|
||||
|
||||
def test_refuteCryptedPassword(self):
|
||||
"""
|
||||
L{verifyCryptedPassword} returns C{False} if the plaintext password
|
||||
passed to it does not match the encrypted password passed to it.
|
||||
"""
|
||||
password = 'string secret'
|
||||
wrong = 'secret string'
|
||||
crypted = crypt.crypt(password, password)
|
||||
self.assertFalse(
|
||||
checkers.verifyCryptedPassword(crypted, wrong),
|
||||
'%r not supposed to be valid encrypted password for %s' % (
|
||||
crypted, wrong))
|
||||
|
||||
|
||||
def test_pwdGetByName(self):
|
||||
"""
|
||||
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
|
||||
database if the L{pwd} module is present.
|
||||
"""
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser(
|
||||
'alice', 'secrit', 1, 2, 'first last', '/foo', '/bin/sh')
|
||||
self.patch(checkers, 'pwd', userdb)
|
||||
self.assertEqual(
|
||||
checkers._pwdGetByName('alice'), userdb.getpwnam('alice'))
|
||||
|
||||
|
||||
def test_pwdGetByNameWithoutPwd(self):
|
||||
"""
|
||||
If the C{pwd} module isn't present, L{_pwdGetByName} returns C{None}.
|
||||
"""
|
||||
self.patch(checkers, 'pwd', None)
|
||||
self.assertIs(checkers._pwdGetByName('alice'), None)
|
||||
|
||||
|
||||
def test_shadowGetByName(self):
|
||||
"""
|
||||
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
|
||||
database if the L{spwd} is present.
|
||||
"""
|
||||
userdb = ShadowDatabase()
|
||||
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
|
||||
self.patch(checkers, 'spwd', userdb)
|
||||
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
self.patch(util, 'os', self.mockos)
|
||||
|
||||
self.assertEqual(
|
||||
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def test_shadowGetByNameWithoutSpwd(self):
|
||||
"""
|
||||
L{_shadowGetByName} uses the C{shadow} module to return a tuple of items
|
||||
from the UNIX /etc/shadow database if the C{spwd} module is not present
|
||||
and the C{shadow} module is.
|
||||
"""
|
||||
userdb = ShadowDatabase()
|
||||
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
|
||||
self.patch(checkers, 'spwd', None)
|
||||
self.patch(checkers, 'shadow', userdb)
|
||||
self.patch(util, 'os', self.mockos)
|
||||
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
|
||||
self.assertEqual(
|
||||
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def test_shadowGetByNameWithoutEither(self):
|
||||
"""
|
||||
L{_shadowGetByName} returns C{None} if neither C{spwd} nor C{shadow} is
|
||||
present.
|
||||
"""
|
||||
self.patch(checkers, 'spwd', None)
|
||||
self.patch(checkers, 'shadow', None)
|
||||
|
||||
self.assertIs(checkers._shadowGetByName('bob'), None)
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
|
||||
class SSHPublicKeyDatabaseTests(TestCase):
|
||||
"""
|
||||
Tests for L{SSHPublicKeyDatabase}.
|
||||
"""
|
||||
skip = euidSkip or dependencySkip
|
||||
|
||||
def setUp(self):
|
||||
self.checker = checkers.SSHPublicKeyDatabase()
|
||||
self.key1 = base64.encodestring("foobar")
|
||||
self.key2 = base64.encodestring("eggspam")
|
||||
self.content = "t1 %s foo\nt2 %s egg\n" % (self.key1, self.key2)
|
||||
|
||||
self.mockos = MockOS()
|
||||
self.mockos.path = FilePath(self.mktemp())
|
||||
self.mockos.path.makedirs()
|
||||
self.patch(util, 'os', self.mockos)
|
||||
self.sshDir = self.mockos.path.child('.ssh')
|
||||
self.sshDir.makedirs()
|
||||
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser(
|
||||
'user', 'password', 1, 2, 'first last',
|
||||
self.mockos.path.path, '/bin/shell')
|
||||
self.checker._userdb = userdb
|
||||
|
||||
|
||||
def test_deprecated(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase} is deprecated as of version 15.0
|
||||
"""
|
||||
warningsShown = self.flushWarnings(
|
||||
offendingFunctions=[self.setUp])
|
||||
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warningsShown[0]['message'],
|
||||
"twisted.conch.checkers.SSHPublicKeyDatabase "
|
||||
"was deprecated in Twisted 15.0.0: Please use "
|
||||
"twisted.conch.checkers.SSHPublicKeyChecker, "
|
||||
"initialized with an instance of "
|
||||
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead.")
|
||||
self.assertEqual(len(warningsShown), 1)
|
||||
|
||||
|
||||
def _testCheckKey(self, filename):
|
||||
self.sshDir.child(filename).setContent(self.content)
|
||||
user = UsernamePassword("user", "password")
|
||||
user.blob = "foobar"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
user.blob = "eggspam"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
user.blob = "notallowed"
|
||||
self.assertFalse(self.checker.checkKey(user))
|
||||
|
||||
|
||||
def test_checkKey(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
|
||||
authorized_keys file and check the keys against that file.
|
||||
"""
|
||||
self._testCheckKey("authorized_keys")
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
def test_checkKey2(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
|
||||
authorized_keys2 file and check the keys against that file.
|
||||
"""
|
||||
self._testCheckKey("authorized_keys2")
|
||||
self.assertEqual(self.mockos.seteuidCalls, [])
|
||||
self.assertEqual(self.mockos.setegidCalls, [])
|
||||
|
||||
|
||||
def test_checkKeyAsRoot(self):
|
||||
"""
|
||||
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
|
||||
switch its uid/gid to the ones of the authenticated user.
|
||||
"""
|
||||
keyFile = self.sshDir.child("authorized_keys")
|
||||
keyFile.setContent(self.content)
|
||||
# Fake permission error by changing the mode
|
||||
keyFile.chmod(0000)
|
||||
self.addCleanup(keyFile.chmod, 0777)
|
||||
# And restore the right mode when seteuid is called
|
||||
savedSeteuid = self.mockos.seteuid
|
||||
def seteuid(euid):
|
||||
keyFile.chmod(0777)
|
||||
return savedSeteuid(euid)
|
||||
self.mockos.euid = 2345
|
||||
self.mockos.egid = 1234
|
||||
self.patch(self.mockos, "seteuid", seteuid)
|
||||
self.patch(util, 'os', self.mockos)
|
||||
user = UsernamePassword("user", "password")
|
||||
user.blob = "foobar"
|
||||
self.assertTrue(self.checker.checkKey(user))
|
||||
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
|
||||
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
|
||||
|
||||
|
||||
def test_requestAvatarId(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
|
||||
passed in if its C{_checkKey} method returns True.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
|
||||
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
def _verify(avatarId):
|
||||
self.assertEqual(avatarId, 'test')
|
||||
return d.addCallback(_verify)
|
||||
|
||||
|
||||
def test_requestAvatarIdWithoutSignature(self):
|
||||
"""
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
|
||||
if the credentials represent a valid key without a signature. This
|
||||
tells the user that the key is valid for login, but does not actually
|
||||
allow that user to do so without a signature.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, None, None)
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
return self.assertFailure(d, ValidPublicKey)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidKey(self):
|
||||
"""
|
||||
If L{SSHPublicKeyDatabase.checkKey} returns False,
|
||||
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return False
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
d = self.checker.requestAvatarId(None);
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidSignature(self):
|
||||
"""
|
||||
Valid keys with invalid signatures should cause
|
||||
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
|
||||
failure
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey(
|
||||
'test', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
|
||||
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_requestAvatarIdNormalizeException(self):
|
||||
"""
|
||||
Exceptions raised while verifying the key should be normalized into an
|
||||
C{UnauthorizedLogin} failure.
|
||||
"""
|
||||
def _checkKey(ignored):
|
||||
return True
|
||||
self.patch(self.checker, 'checkKey', _checkKey)
|
||||
credentials = SSHPrivateKey('test', None, 'blob', 'sigData', 'sig')
|
||||
d = self.checker.requestAvatarId(credentials)
|
||||
def _verifyLoggedException(failure):
|
||||
errors = self.flushLoggedErrors(keys.BadKeyError)
|
||||
self.assertEqual(len(errors), 1)
|
||||
return failure
|
||||
d.addErrback(_verifyLoggedException)
|
||||
return self.assertFailure(d, UnauthorizedLogin)
|
||||
|
||||
|
||||
|
||||
class SSHProtocolCheckerTests(TestCase):
|
||||
"""
|
||||
Tests for L{SSHProtocolChecker}.
|
||||
"""
|
||||
|
||||
skip = dependencySkip
|
||||
|
||||
def test_registerChecker(self):
|
||||
"""
|
||||
L{SSHProcotolChecker.registerChecker} should add the given checker to
|
||||
the list of registered checkers.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
self.assertEqual(checker.credentialInterfaces, [])
|
||||
checker.registerChecker(checkers.SSHPublicKeyDatabase(), )
|
||||
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
|
||||
self.assertIsInstance(checker.checkers[ISSHPrivateKey],
|
||||
checkers.SSHPublicKeyDatabase)
|
||||
|
||||
|
||||
def test_registerCheckerWithInterface(self):
|
||||
"""
|
||||
If a apecific interface is passed into
|
||||
L{SSHProtocolChecker.registerChecker}, that interface should be
|
||||
registered instead of what the checker specifies in
|
||||
credentialIntefaces.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
self.assertEqual(checker.credentialInterfaces, [])
|
||||
checker.registerChecker(checkers.SSHPublicKeyDatabase(),
|
||||
IUsernamePassword)
|
||||
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
|
||||
self.assertIsInstance(checker.checkers[IUsernamePassword],
|
||||
checkers.SSHPublicKeyDatabase)
|
||||
|
||||
|
||||
def test_requestAvatarId(self):
|
||||
"""
|
||||
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
|
||||
registered checkers to authenticate a user.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
|
||||
passwordDatabase.addUser('test', 'test')
|
||||
checker.registerChecker(passwordDatabase)
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
def _callback(avatarId):
|
||||
self.assertEqual(avatarId, 'test')
|
||||
return d.addCallback(_callback)
|
||||
|
||||
|
||||
def test_requestAvatarIdWithNotEnoughAuthentication(self):
|
||||
"""
|
||||
If the client indicates that it is never satisfied, by always returning
|
||||
False from _areDone, then L{SSHProtocolChecker} should raise
|
||||
L{NotEnoughAuthentication}.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
def _areDone(avatarId):
|
||||
return False
|
||||
self.patch(checker, 'areDone', _areDone)
|
||||
|
||||
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
|
||||
passwordDatabase.addUser('test', 'test')
|
||||
checker.registerChecker(passwordDatabase)
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
return self.assertFailure(d, NotEnoughAuthentication)
|
||||
|
||||
|
||||
def test_requestAvatarIdInvalidCredential(self):
|
||||
"""
|
||||
If the passed credentials aren't handled by any registered checker,
|
||||
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
|
||||
"""
|
||||
checker = checkers.SSHProtocolChecker()
|
||||
d = checker.requestAvatarId(UsernamePassword('test', 'test'))
|
||||
return self.assertFailure(d, UnhandledCredentials)
|
||||
|
||||
|
||||
def test_areDone(self):
|
||||
"""
|
||||
The default L{SSHProcotolChecker.areDone} should simply return True.
|
||||
"""
|
||||
self.assertEqual(checkers.SSHProtocolChecker().areDone(None), True)
|
||||
|
||||
|
||||
|
||||
class UNIXPasswordDatabaseTests(TestCase):
|
||||
"""
|
||||
Tests for L{UNIXPasswordDatabase}.
|
||||
"""
|
||||
skip = cryptSkip or dependencySkip
|
||||
|
||||
def assertLoggedIn(self, d, username):
|
||||
"""
|
||||
Assert that the L{Deferred} passed in is called back with the value
|
||||
'username'. This represents a valid login for this TestCase.
|
||||
|
||||
NOTE: To work, this method's return value must be returned from the
|
||||
test method, or otherwise hooked up to the test machinery.
|
||||
|
||||
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
|
||||
@type d: L{Deferred}
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
result = []
|
||||
d.addBoth(result.append)
|
||||
self.assertEqual(len(result), 1, "login incomplete")
|
||||
if isinstance(result[0], Failure):
|
||||
result[0].raiseException()
|
||||
self.assertEqual(result[0], username)
|
||||
|
||||
|
||||
def test_defaultCheckers(self):
|
||||
"""
|
||||
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
|
||||
and then the C{spwd} database.
|
||||
"""
|
||||
checker = checkers.UNIXPasswordDatabase()
|
||||
|
||||
def crypted(username, password):
|
||||
salt = crypt.crypt(password, username)
|
||||
crypted = crypt.crypt(password, '$1$' + salt)
|
||||
return crypted
|
||||
|
||||
pwd = UserDatabase()
|
||||
pwd.addUser('alice', crypted('alice', 'password'),
|
||||
1, 2, 'foo', '/foo', '/bin/sh')
|
||||
# x and * are convention for "look elsewhere for the password"
|
||||
pwd.addUser('bob', 'x', 1, 2, 'bar', '/bar', '/bin/sh')
|
||||
spwd = ShadowDatabase()
|
||||
spwd.addUser('alice', 'wrong', 1, 2, 3, 4, 5, 6, 7)
|
||||
spwd.addUser('bob', crypted('bob', 'password'),
|
||||
8, 9, 10, 11, 12, 13, 14)
|
||||
|
||||
self.patch(checkers, 'pwd', pwd)
|
||||
self.patch(checkers, 'spwd', spwd)
|
||||
|
||||
mockos = MockOS()
|
||||
self.patch(util, 'os', mockos)
|
||||
|
||||
mockos.euid = 2345
|
||||
mockos.egid = 1234
|
||||
|
||||
cred = UsernamePassword("alice", "password")
|
||||
self.assertLoggedIn(checker.requestAvatarId(cred), 'alice')
|
||||
self.assertEqual(mockos.seteuidCalls, [])
|
||||
self.assertEqual(mockos.setegidCalls, [])
|
||||
cred.username = "bob"
|
||||
self.assertLoggedIn(checker.requestAvatarId(cred), 'bob')
|
||||
self.assertEqual(mockos.seteuidCalls, [0, 2345])
|
||||
self.assertEqual(mockos.setegidCalls, [0, 1234])
|
||||
|
||||
|
||||
def assertUnauthorizedLogin(self, d):
|
||||
"""
|
||||
Asserts that the L{Deferred} passed in is erred back with an
|
||||
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
|
||||
this TestCase.
|
||||
|
||||
NOTE: To work, this method's return value must be returned from the
|
||||
test method, or otherwise hooked up to the test machinery.
|
||||
|
||||
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
|
||||
@type d: L{Deferred}
|
||||
@rtype: L{None}
|
||||
"""
|
||||
self.assertRaises(
|
||||
checkers.UnauthorizedLogin, self.assertLoggedIn, d, 'bogus value')
|
||||
|
||||
|
||||
def test_passInCheckers(self):
|
||||
"""
|
||||
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
|
||||
user information.
|
||||
"""
|
||||
password = crypt.crypt('secret', 'secret')
|
||||
userdb = UserDatabase()
|
||||
userdb.addUser('anybody', password, 1, 2, 'foo', '/bar', '/bin/sh')
|
||||
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
|
||||
self.assertLoggedIn(
|
||||
checker.requestAvatarId(UsernamePassword('anybody', 'secret')),
|
||||
'anybody')
|
||||
|
||||
|
||||
def test_verifyPassword(self):
|
||||
"""
|
||||
If the encrypted password provided by the getpwnam function is valid
|
||||
(verified by the L{verifyCryptedPassword} function), we callback the
|
||||
C{requestAvatarId} L{Deferred} with the username.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return crypted == pw
|
||||
def getpwnam(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
|
||||
|
||||
|
||||
def test_failOnKeyError(self):
|
||||
"""
|
||||
If the getpwnam function raises a KeyError, the login fails with an
|
||||
L{UnauthorizedLogin} exception.
|
||||
"""
|
||||
def getpwnam(username):
|
||||
raise KeyError(username)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
|
||||
|
||||
|
||||
def test_failOnBadPassword(self):
|
||||
"""
|
||||
If the verifyCryptedPassword function doesn't verify the password, the
|
||||
login fails with an L{UnauthorizedLogin} exception.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return False
|
||||
def getpwnam(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
|
||||
|
||||
|
||||
def test_loopThroughFunctions(self):
|
||||
"""
|
||||
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
|
||||
function associated with it and returns a L{Deferred} which fires with
|
||||
the result of the first one which returns a value other than None.
|
||||
ones do not verify the password.
|
||||
"""
|
||||
def verifyCryptedPassword(crypted, pw):
|
||||
return crypted == pw
|
||||
def getpwnam1(username):
|
||||
return [username, 'not the password']
|
||||
def getpwnam2(username):
|
||||
return [username, username]
|
||||
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
|
||||
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
|
||||
credential = UsernamePassword('username', 'username')
|
||||
self.assertLoggedIn(checker.requestAvatarId(credential), 'username')
|
||||
|
||||
|
||||
def test_failOnSpecial(self):
|
||||
"""
|
||||
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
|
||||
is not compared against the supplied password. Instead it is skipped.
|
||||
"""
|
||||
pwd = UserDatabase()
|
||||
pwd.addUser('alice', '', 1, 2, '', 'foo', 'bar')
|
||||
pwd.addUser('bob', 'x', 1, 2, '', 'foo', 'bar')
|
||||
pwd.addUser('carol', '*', 1, 2, '', 'foo', 'bar')
|
||||
self.patch(checkers, 'pwd', pwd)
|
||||
|
||||
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
|
||||
cred = UsernamePassword('alice', '')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
||||
cred = UsernamePassword('bob', 'x')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
||||
cred = UsernamePassword('carol', '*')
|
||||
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
|
||||
|
||||
|
||||
|
||||
class AuthorizedKeyFileReaderTests(TestCase):
|
||||
"""
|
||||
Tests for L{checkers.readAuthorizedKeyFile}
|
||||
"""
|
||||
skip = dependencySkip
|
||||
|
||||
|
||||
def test_ignoresComments(self):
|
||||
"""
|
||||
L{checkers.readAuthorizedKeyFile} does not attempt to turn comments
|
||||
into keys
|
||||
"""
|
||||
fileobj = StringIO(u'# this comment is ignored\n'
|
||||
u'this is not\n'
|
||||
u'# this is again\n'
|
||||
u'and this is not')
|
||||
result = checkers.readAuthorizedKeyFile(fileobj, lambda x: x)
|
||||
self.assertEqual(['this is not', 'and this is not'], list(result))
|
||||
|
||||
|
||||
def test_ignoresLeadingWhitespaceAndEmptyLines(self):
|
||||
"""
|
||||
L{checkers.readAuthorizedKeyFile} ignores leading whitespace in
|
||||
lines, as well as empty lines
|
||||
"""
|
||||
fileobj = StringIO(u"""
|
||||
# ignore
|
||||
not ignored
|
||||
""")
|
||||
result = checkers.readAuthorizedKeyFile(fileobj, parseKey=lambda x: x)
|
||||
self.assertEqual(['not ignored'], list(result))
|
||||
|
||||
|
||||
def test_ignoresUnparsableKeys(self):
|
||||
"""
|
||||
L{checkers.readAuthorizedKeyFile} does not raise an exception
|
||||
when a key fails to parse (raises a
|
||||
L{twisted.conch.ssh.keys.BadKeyError}), but rather just keeps going
|
||||
"""
|
||||
def failOnSome(line):
|
||||
if line.startswith('f'):
|
||||
raise keys.BadKeyError('failed to parse')
|
||||
return line
|
||||
|
||||
fileobj = StringIO(u'failed key\ngood key')
|
||||
result = checkers.readAuthorizedKeyFile(fileobj,
|
||||
parseKey=failOnSome)
|
||||
self.assertEqual(['good key'], list(result))
|
||||
|
||||
|
||||
|
||||
class InMemorySSHKeyDBTests(TestCase):
|
||||
"""
|
||||
Tests for L{checkers.InMemorySSHKeyDB}
|
||||
"""
|
||||
skip = dependencySkip
|
||||
|
||||
|
||||
def test_implementsInterface(self):
|
||||
"""
|
||||
L{checkers.InMemorySSHKeyDB} implements
|
||||
L{checkers.IAuthorizedKeysDB}
|
||||
"""
|
||||
keydb = checkers.InMemorySSHKeyDB({'alice': ['key']})
|
||||
verifyObject(checkers.IAuthorizedKeysDB, keydb)
|
||||
|
||||
|
||||
def test_noKeysForUnauthorizedUser(self):
|
||||
"""
|
||||
If the user is not in the mapping provided to
|
||||
L{checkers.InMemorySSHKeyDB}, an empty iterator is returned
|
||||
by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
|
||||
"""
|
||||
keydb = checkers.InMemorySSHKeyDB({'alice': ['keys']})
|
||||
self.assertEqual([], list(keydb.getAuthorizedKeys('bob')))
|
||||
|
||||
|
||||
def test_allKeysForAuthorizedUser(self):
|
||||
"""
|
||||
If the user is in the mapping provided to
|
||||
L{checkers.InMemorySSHKeyDB}, an iterator with all the keys
|
||||
is returned by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
|
||||
"""
|
||||
keydb = checkers.InMemorySSHKeyDB({'alice': ['a', 'b']})
|
||||
self.assertEqual(['a', 'b'], list(keydb.getAuthorizedKeys('alice')))
|
||||
|
||||
|
||||
|
||||
class UNIXAuthorizedKeysFilesTests(TestCase):
|
||||
"""
|
||||
Tests for L{checkers.UNIXAuthorizedKeysFiles}.
|
||||
"""
|
||||
skip = dependencySkip
|
||||
|
||||
|
||||
def setUp(self):
|
||||
mockos = MockOS()
|
||||
mockos.path = FilePath(self.mktemp())
|
||||
mockos.path.makedirs()
|
||||
|
||||
self.userdb = UserDatabase()
|
||||
self.userdb.addUser('alice', 'password', 1, 2, 'alice lastname',
|
||||
mockos.path.path, '/bin/shell')
|
||||
|
||||
self.sshDir = mockos.path.child('.ssh')
|
||||
self.sshDir.makedirs()
|
||||
authorizedKeys = self.sshDir.child('authorized_keys')
|
||||
authorizedKeys.setContent('key 1\nkey 2')
|
||||
|
||||
self.expectedKeys = ['key 1', 'key 2']
|
||||
|
||||
|
||||
def test_implementsInterface(self):
|
||||
"""
|
||||
L{checkers.UNIXAuthorizedKeysFiles} implements
|
||||
L{checkers.IAuthorizedKeysDB}.
|
||||
"""
|
||||
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb)
|
||||
verifyObject(checkers.IAuthorizedKeysDB, keydb)
|
||||
|
||||
|
||||
def test_noKeysForUnauthorizedUser(self):
|
||||
"""
|
||||
If the user is not in the user database provided to
|
||||
L{checkers.UNIXAuthorizedKeysFiles}, an empty iterator is returned
|
||||
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
|
||||
"""
|
||||
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
|
||||
parseKey=lambda x: x)
|
||||
self.assertEqual([], list(keydb.getAuthorizedKeys('bob')))
|
||||
|
||||
|
||||
def test_allKeysInAllAuthorizedFilesForAuthorizedUser(self):
|
||||
"""
|
||||
If the user is in the user database provided to
|
||||
L{checkers.UNIXAuthorizedKeysFiles}, an iterator with all the keys in
|
||||
C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2} is returned
|
||||
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
|
||||
"""
|
||||
self.sshDir.child('authorized_keys2').setContent('key 3')
|
||||
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
|
||||
parseKey=lambda x: x)
|
||||
self.assertEqual(self.expectedKeys + ['key 3'],
|
||||
list(keydb.getAuthorizedKeys('alice')))
|
||||
|
||||
|
||||
def test_ignoresNonexistantFile(self):
|
||||
"""
|
||||
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
|
||||
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
|
||||
if they exist.
|
||||
"""
|
||||
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
|
||||
parseKey=lambda x: x)
|
||||
self.assertEqual(self.expectedKeys,
|
||||
list(keydb.getAuthorizedKeys('alice')))
|
||||
|
||||
|
||||
def test_ignoresUnreadableFile(self):
|
||||
"""
|
||||
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
|
||||
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
|
||||
if they are readable.
|
||||
"""
|
||||
self.sshDir.child('authorized_keys2').makedirs()
|
||||
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
|
||||
parseKey=lambda x: x)
|
||||
self.assertEqual(self.expectedKeys,
|
||||
list(keydb.getAuthorizedKeys('alice')))
|
||||
|
||||
|
||||
|
||||
_KeyDB = namedtuple('KeyDB', ['getAuthorizedKeys'])
|
||||
|
||||
|
||||
|
||||
class _DummyException(Exception):
|
||||
"""
|
||||
Fake exception to be used for testing.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class SSHPublicKeyCheckerTests(TestCase):
|
||||
"""
|
||||
Tests for L{checkers.SSHPublicKeyChecker}.
|
||||
"""
|
||||
skip = dependencySkip
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.credentials = SSHPrivateKey(
|
||||
'alice', 'ssh-rsa', keydata.publicRSA_openssh, 'foo',
|
||||
keys.Key.fromString(keydata.privateRSA_openssh).sign('foo'))
|
||||
self.keydb = _KeyDB(lambda _: [
|
||||
keys.Key.fromString(keydata.publicRSA_openssh)])
|
||||
self.checker = checkers.SSHPublicKeyChecker(self.keydb)
|
||||
|
||||
|
||||
def test_credentialsWithoutSignature(self):
|
||||
"""
|
||||
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
|
||||
credentials that do not have a signature fails with L{ValidPublicKey}.
|
||||
"""
|
||||
self.credentials.signature = None
|
||||
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
|
||||
ValidPublicKey)
|
||||
|
||||
|
||||
def test_credentialsWithBadKey(self):
|
||||
"""
|
||||
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
|
||||
credentials that have a bad key fails with L{keys.BadKeyError}.
|
||||
"""
|
||||
self.credentials.blob = ''
|
||||
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
|
||||
keys.BadKeyError)
|
||||
|
||||
|
||||
def test_credentialsNoMatchingKey(self):
|
||||
"""
|
||||
If L{checkers.IAuthorizedKeysDB.getAuthorizedKeys} returns no keys
|
||||
that match the credentials,
|
||||
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
|
||||
L{UnauthorizedLogin}.
|
||||
"""
|
||||
self.credentials.blob = keydata.publicDSA_openssh
|
||||
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
|
||||
UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_credentialsInvalidSignature(self):
|
||||
"""
|
||||
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
|
||||
credentials that are incorrectly signed fails with
|
||||
L{UnauthorizedLogin}.
|
||||
"""
|
||||
self.credentials.signature = (
|
||||
keys.Key.fromString(keydata.privateDSA_openssh).sign('foo'))
|
||||
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
|
||||
UnauthorizedLogin)
|
||||
|
||||
|
||||
def test_failureVerifyingKey(self):
|
||||
"""
|
||||
If L{keys.Key.verify} raises an exception,
|
||||
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
|
||||
L{UnauthorizedLogin}.
|
||||
"""
|
||||
def fail(*args, **kwargs):
|
||||
raise _DummyException()
|
||||
|
||||
self.patch(keys.Key, 'verify', fail)
|
||||
|
||||
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
|
||||
UnauthorizedLogin)
|
||||
self.flushLoggedErrors(_DummyException)
|
||||
|
||||
|
||||
def test_usernameReturnedOnSuccess(self):
|
||||
"""
|
||||
L{checker.SSHPublicKeyChecker.requestAvatarId}, if successful,
|
||||
callbacks with the username.
|
||||
"""
|
||||
d = self.checker.requestAvatarId(self.credentials)
|
||||
self.assertEqual('alice', self.successResultOf(d))
|
@ -0,0 +1,339 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.ckeygen}.
|
||||
"""
|
||||
|
||||
import getpass
|
||||
import sys
|
||||
from StringIO import StringIO
|
||||
|
||||
from twisted.python.reflect import requireModule
|
||||
|
||||
if requireModule('Crypto') and requireModule('pyasn1'):
|
||||
from twisted.conch.ssh.keys import Key, BadKeyError
|
||||
from twisted.conch.scripts.ckeygen import (
|
||||
changePassPhrase, displayPublicKey, printFingerprint, _saveKey)
|
||||
else:
|
||||
skip = "PyCrypto and pyasn1 required for twisted.conch.scripts.ckeygen."
|
||||
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.conch.test.keydata import (
|
||||
publicRSA_openssh, privateRSA_openssh, privateRSA_openssh_encrypted)
|
||||
|
||||
|
||||
|
||||
def makeGetpass(*passphrases):
|
||||
"""
|
||||
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
|
||||
time called. Use case is to provide an old, then new passphrase(s) as if
|
||||
requested interactively.
|
||||
|
||||
@param passphrases: The list of passphrases returned, one per each call.
|
||||
"""
|
||||
passphrases = iter(passphrases)
|
||||
|
||||
def fakeGetpass(_):
|
||||
return passphrases.next()
|
||||
|
||||
return fakeGetpass
|
||||
|
||||
|
||||
|
||||
class KeyGenTests(TestCase):
|
||||
"""
|
||||
Tests for various functions used to implement the I{ckeygen} script.
|
||||
"""
|
||||
def setUp(self):
|
||||
"""
|
||||
Patch C{sys.stdout} with a L{StringIO} instance to tests can make
|
||||
assertions about what's printed.
|
||||
"""
|
||||
self.stdout = StringIO()
|
||||
self.patch(sys, 'stdout', self.stdout)
|
||||
|
||||
|
||||
def test_printFingerprint(self):
|
||||
"""
|
||||
L{printFingerprint} writes a line to standard out giving the number of
|
||||
bits of the key, its fingerprint, and the basename of the file from it
|
||||
was read.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(publicRSA_openssh)
|
||||
printFingerprint({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue(),
|
||||
'768 3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af temp\n')
|
||||
|
||||
|
||||
def test_saveKey(self):
|
||||
"""
|
||||
L{_saveKey} writes the private and public parts of a key to two
|
||||
different files and writes a report of this to standard out.
|
||||
"""
|
||||
base = FilePath(self.mktemp())
|
||||
base.makedirs()
|
||||
filename = base.child('id_rsa').path
|
||||
key = Key.fromString(privateRSA_openssh)
|
||||
_saveKey(
|
||||
key.keyObject,
|
||||
{'filename': filename, 'pass': 'passphrase'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue(),
|
||||
"Your identification has been saved in %s\n"
|
||||
"Your public key has been saved in %s.pub\n"
|
||||
"The key fingerprint is:\n"
|
||||
"3d:13:5f:cb:c9:79:8a:93:06:27:65:bc:3d:0b:8f:af\n" % (
|
||||
filename,
|
||||
filename))
|
||||
self.assertEqual(
|
||||
key.fromString(
|
||||
base.child('id_rsa').getContent(), None, 'passphrase'),
|
||||
key)
|
||||
self.assertEqual(
|
||||
Key.fromString(base.child('id_rsa.pub').getContent()),
|
||||
key.public())
|
||||
|
||||
|
||||
def test_saveKeyEmptyPassphrase(self):
|
||||
"""
|
||||
L{_saveKey} will choose an empty string for the passphrase if
|
||||
no-passphrase is C{True}.
|
||||
"""
|
||||
base = FilePath(self.mktemp())
|
||||
base.makedirs()
|
||||
filename = base.child('id_rsa').path
|
||||
key = Key.fromString(privateRSA_openssh)
|
||||
_saveKey(
|
||||
key.keyObject,
|
||||
{'filename': filename, 'no-passphrase': True})
|
||||
self.assertEqual(
|
||||
key.fromString(
|
||||
base.child('id_rsa').getContent(), None, b''),
|
||||
key)
|
||||
|
||||
|
||||
def test_displayPublicKey(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
displayPublicKey({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyEncrypted(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key using the given passphrase when it's encrypted.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
displayPublicKey({'filename': filename, 'pass': 'encrypted'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyEncryptedPassphrasePrompt(self):
|
||||
"""
|
||||
L{displayPublicKey} prints out the public key associated with a given
|
||||
private key, asking for the passphrase when it's encrypted.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
pubKey = Key.fromString(publicRSA_openssh)
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
self.patch(getpass, 'getpass', lambda x: 'encrypted')
|
||||
displayPublicKey({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
pubKey.toString('openssh'))
|
||||
|
||||
|
||||
def test_displayPublicKeyWrongPassphrase(self):
|
||||
"""
|
||||
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
|
||||
an encrypted key with the wrong password.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
self.assertRaises(
|
||||
BadKeyError, displayPublicKey,
|
||||
{'filename': filename, 'pass': 'wrong'})
|
||||
|
||||
|
||||
def test_changePassphrase(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a
|
||||
private key interactively.
|
||||
"""
|
||||
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
|
||||
self.patch(getpass, 'getpass', oldNewConfirm)
|
||||
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase({'filename': filename})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWithOld(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a
|
||||
private key, providing the old passphrase and prompting for new one.
|
||||
"""
|
||||
newConfirm = makeGetpass('newpass', 'newpass')
|
||||
self.patch(getpass, 'getpass', newConfirm)
|
||||
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase({'filename': filename, 'pass': 'encrypted'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWithBoth(self):
|
||||
"""
|
||||
L{changePassPhrase} allows a user to change the passphrase of a private
|
||||
key by providing both old and new passphrases without prompting.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
|
||||
changePassPhrase(
|
||||
{'filename': filename, 'pass': 'encrypted',
|
||||
'newpass': 'newencrypt'})
|
||||
self.assertEqual(
|
||||
self.stdout.getvalue().strip('\n'),
|
||||
'Your identification has been saved with the new passphrase.')
|
||||
self.assertNotEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseWrongPassphrase(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if passed an invalid old passphrase when
|
||||
trying to change the passphrase of a private key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'pass': 'wrong'})
|
||||
self.assertEqual('Could not change passphrase: old passphrase error',
|
||||
str(error))
|
||||
self.assertEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseEmptyGetPass(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if no passphrase is specified for the
|
||||
C{getpass} call and the key is encrypted.
|
||||
"""
|
||||
self.patch(getpass, 'getpass', makeGetpass(''))
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh_encrypted)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase, {'filename': filename})
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: Passphrase must be provided '
|
||||
'for an encrypted key',
|
||||
str(error))
|
||||
self.assertEqual(privateRSA_openssh_encrypted,
|
||||
FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseBadKey(self):
|
||||
"""
|
||||
L{changePassPhrase} exits if the file specified points to an invalid
|
||||
key.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent('foobar')
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase, {'filename': filename})
|
||||
self.assertEqual(
|
||||
"Could not change passphrase: cannot guess the type of 'foobar'",
|
||||
str(error))
|
||||
self.assertEqual('foobar', FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseCreateError(self):
|
||||
"""
|
||||
L{changePassPhrase} doesn't modify the key file if an unexpected error
|
||||
happens when trying to create the key with the new passphrase.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
|
||||
def toString(*args, **kwargs):
|
||||
raise RuntimeError('oops')
|
||||
|
||||
self.patch(Key, 'toString', toString)
|
||||
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename,
|
||||
'newpass': 'newencrypt'})
|
||||
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: oops', str(error))
|
||||
|
||||
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphraseEmptyStringError(self):
|
||||
"""
|
||||
L{changePassPhrase} doesn't modify the key file if C{toString} returns
|
||||
an empty string.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(privateRSA_openssh)
|
||||
|
||||
def toString(*args, **kwargs):
|
||||
return ''
|
||||
|
||||
self.patch(Key, 'toString', toString)
|
||||
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'newpass': 'newencrypt'})
|
||||
|
||||
self.assertEqual(
|
||||
"Could not change passphrase: "
|
||||
"cannot guess the type of ''", str(error))
|
||||
|
||||
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
|
||||
|
||||
|
||||
def test_changePassphrasePublicKey(self):
|
||||
"""
|
||||
L{changePassPhrase} exits when trying to change the passphrase on a
|
||||
public key, and doesn't change the file.
|
||||
"""
|
||||
filename = self.mktemp()
|
||||
FilePath(filename).setContent(publicRSA_openssh)
|
||||
error = self.assertRaises(
|
||||
SystemExit, changePassPhrase,
|
||||
{'filename': filename, 'newpass': 'pass'})
|
||||
self.assertEqual(
|
||||
'Could not change passphrase: key not encrypted', str(error))
|
||||
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())
|
@ -0,0 +1,577 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_conch -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
import os, sys, socket
|
||||
from itertools import count
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.cred import portal
|
||||
from twisted.internet import reactor, defer, protocol
|
||||
from twisted.internet.error import ProcessExitedAlready
|
||||
from twisted.internet.task import LoopingCall
|
||||
from twisted.python import log, runtime
|
||||
from twisted.trial import unittest
|
||||
from twisted.conch.error import ConchError
|
||||
from twisted.conch.avatar import ConchUser
|
||||
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
|
||||
|
||||
try:
|
||||
from twisted.conch.scripts.conch import SSHSession as StdioInteractingSession
|
||||
except ImportError, e:
|
||||
StdioInteractingSession = None
|
||||
_reason = str(e)
|
||||
del e
|
||||
|
||||
from twisted.conch.test.test_ssh import ConchTestRealm
|
||||
from twisted.python.procutils import which
|
||||
|
||||
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
|
||||
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
|
||||
|
||||
from twisted.conch.test.test_ssh import Crypto, pyasn1
|
||||
try:
|
||||
from twisted.conch.test.test_ssh import ConchTestServerFactory, \
|
||||
conchTestPublicKeyChecker
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class FakeStdio(object):
|
||||
"""
|
||||
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
|
||||
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
|
||||
|
||||
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
|
||||
has been called.
|
||||
"""
|
||||
writeConnLost = False
|
||||
|
||||
def loseWriteConnection(self):
|
||||
"""
|
||||
Record the call to loseWriteConnection.
|
||||
"""
|
||||
self.writeConnLost = True
|
||||
|
||||
|
||||
|
||||
class StdioInteractingSessionTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.scripts.conch.SSHSession}.
|
||||
"""
|
||||
if StdioInteractingSession is None:
|
||||
skip = _reason
|
||||
|
||||
def test_eofReceived(self):
|
||||
"""
|
||||
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
|
||||
write half of its stdio connection.
|
||||
"""
|
||||
stdio = FakeStdio()
|
||||
channel = StdioInteractingSession()
|
||||
channel.stdio = stdio
|
||||
channel.eofReceived()
|
||||
self.assertTrue(stdio.writeConnLost)
|
||||
|
||||
|
||||
|
||||
class Echo(protocol.Protocol):
|
||||
def connectionMade(self):
|
||||
log.msg('ECHO CONNECTION MADE')
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
log.msg('ECHO CONNECTION DONE')
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
if '\n' in data:
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class EchoFactory(protocol.Factory):
|
||||
protocol = Echo
|
||||
|
||||
|
||||
|
||||
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Test protocol for launching an OpenSSH client process.
|
||||
|
||||
@ivar deferred: Set by whatever uses this object. Accessed using
|
||||
L{_getDeferred}, which destroys the value so the Deferred is not
|
||||
fired twice. Fires when the process is terminated.
|
||||
"""
|
||||
|
||||
deferred = None
|
||||
buf = ''
|
||||
|
||||
def _getDeferred(self):
|
||||
d, self.deferred = self.deferred, None
|
||||
return d
|
||||
|
||||
|
||||
def outReceived(self, data):
|
||||
self.buf += data
|
||||
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Called when the process has ended.
|
||||
|
||||
@param reason: a Failure giving the reason for the process' end.
|
||||
"""
|
||||
if reason.value.exitCode != 0:
|
||||
self._getDeferred().errback(
|
||||
ConchError("exit code was not 0: %s" %
|
||||
reason.value.exitCode))
|
||||
else:
|
||||
buf = self.buf.replace('\r\n', '\n')
|
||||
self._getDeferred().callback(buf)
|
||||
|
||||
|
||||
|
||||
class ConchTestForwardingProcess(protocol.ProcessProtocol):
|
||||
"""
|
||||
Manages a third-party process which launches a server.
|
||||
|
||||
Uses L{ConchTestForwardingPort} to connect to the third-party server.
|
||||
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
|
||||
a Deferred with the data received by the L{ConchTestForwardingPort}.
|
||||
|
||||
@ivar deferred: Set by whatever uses this object. Accessed using
|
||||
L{_getDeferred}, which destroys the value so the Deferred is not
|
||||
fired twice. Fires when the process is terminated.
|
||||
"""
|
||||
|
||||
deferred = None
|
||||
|
||||
def __init__(self, port, data):
|
||||
"""
|
||||
@type port: C{int}
|
||||
@param port: The port on which the third-party server is listening.
|
||||
(it is assumed that the server is running on localhost).
|
||||
|
||||
@type data: C{str}
|
||||
@param data: This is sent to the third-party server. Must end with '\n'
|
||||
in order to trigger a disconnect.
|
||||
"""
|
||||
self.port = port
|
||||
self.buffer = None
|
||||
self.data = data
|
||||
|
||||
|
||||
def _getDeferred(self):
|
||||
d, self.deferred = self.deferred, None
|
||||
return d
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self._connect()
|
||||
|
||||
|
||||
def _connect(self):
|
||||
"""
|
||||
Connect to the server, which is often a third-party process.
|
||||
Tries to reconnect if it fails because we have no way of determining
|
||||
exactly when the port becomes available for listening -- we can only
|
||||
know when the process starts.
|
||||
"""
|
||||
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self,
|
||||
self.data)
|
||||
d = cc.connectTCP('127.0.0.1', self.port)
|
||||
d.addErrback(self._ebConnect)
|
||||
return d
|
||||
|
||||
|
||||
def _ebConnect(self, f):
|
||||
reactor.callLater(.1, self._connect)
|
||||
|
||||
|
||||
def forwardingPortDisconnected(self, buffer):
|
||||
"""
|
||||
The network connection has died; save the buffer of output
|
||||
from the network and attempt to quit the process gracefully,
|
||||
and then (after the reactor has spun) send it a KILL signal.
|
||||
"""
|
||||
self.buffer = buffer
|
||||
self.transport.write('\x03')
|
||||
self.transport.loseConnection()
|
||||
reactor.callLater(0, self._reallyDie)
|
||||
|
||||
|
||||
def _reallyDie(self):
|
||||
try:
|
||||
self.transport.signalProcess('KILL')
|
||||
except ProcessExitedAlready:
|
||||
pass
|
||||
|
||||
|
||||
def processEnded(self, reason):
|
||||
"""
|
||||
Fire the Deferred at self.deferred with the data collected
|
||||
from the L{ConchTestForwardingPort} connection, if any.
|
||||
"""
|
||||
self._getDeferred().callback(self.buffer)
|
||||
|
||||
|
||||
|
||||
class ConchTestForwardingPort(protocol.Protocol):
|
||||
"""
|
||||
Connects to server launched by a third-party process (managed by
|
||||
L{ConchTestForwardingProcess}) sends data, then reports whatever it
|
||||
received back to the L{ConchTestForwardingProcess} once the connection
|
||||
is ended.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, protocol, data):
|
||||
"""
|
||||
@type protocol: L{ConchTestForwardingProcess}
|
||||
@param protocol: The L{ProcessProtocol} which made this connection.
|
||||
|
||||
@type data: str
|
||||
@param data: The data to be sent to the third-party server.
|
||||
"""
|
||||
self.protocol = protocol
|
||||
self.data = data
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.buffer = ''
|
||||
self.transport.write(self.data)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.buffer += data
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.protocol.forwardingPortDisconnected(self.buffer)
|
||||
|
||||
|
||||
|
||||
def _makeArgs(args, mod="conch"):
|
||||
start = [sys.executable, '-c'
|
||||
"""
|
||||
### Twisted Preamble
|
||||
import sys, os
|
||||
path = os.path.abspath(sys.argv[0])
|
||||
while os.path.dirname(path) != path:
|
||||
if os.path.basename(path).startswith('Twisted'):
|
||||
sys.path.insert(0, path)
|
||||
break
|
||||
path = os.path.dirname(path)
|
||||
|
||||
from twisted.conch.scripts.%s import run
|
||||
run()""" % mod]
|
||||
return start + list(args)
|
||||
|
||||
|
||||
|
||||
class ConchServerSetupMixin:
|
||||
if not Crypto:
|
||||
skip = "can't run w/o PyCrypto"
|
||||
|
||||
if not pyasn1:
|
||||
skip = "Cannot run without PyASN1"
|
||||
|
||||
realmFactory = staticmethod(lambda: ConchTestRealm('testuser'))
|
||||
|
||||
def _createFiles(self):
|
||||
for f in ['rsa_test','rsa_test.pub','dsa_test','dsa_test.pub',
|
||||
'kh_test']:
|
||||
if os.path.exists(f):
|
||||
os.remove(f)
|
||||
open('rsa_test','w').write(privateRSA_openssh)
|
||||
open('rsa_test.pub','w').write(publicRSA_openssh)
|
||||
open('dsa_test.pub','w').write(publicDSA_openssh)
|
||||
open('dsa_test','w').write(privateDSA_openssh)
|
||||
os.chmod('dsa_test', 33152)
|
||||
os.chmod('rsa_test', 33152)
|
||||
open('kh_test','w').write('127.0.0.1 '+publicRSA_openssh)
|
||||
|
||||
|
||||
def _getFreePort(self):
|
||||
s = socket.socket()
|
||||
s.bind(('', 0))
|
||||
port = s.getsockname()[1]
|
||||
s.close()
|
||||
return port
|
||||
|
||||
|
||||
def _makeConchFactory(self):
|
||||
"""
|
||||
Make a L{ConchTestServerFactory}, which allows us to start a
|
||||
L{ConchTestServer} -- i.e. an actually listening conch.
|
||||
"""
|
||||
realm = self.realmFactory()
|
||||
p = portal.Portal(realm)
|
||||
p.registerChecker(conchTestPublicKeyChecker())
|
||||
factory = ConchTestServerFactory()
|
||||
factory.portal = p
|
||||
return factory
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self._createFiles()
|
||||
self.conchFactory = self._makeConchFactory()
|
||||
self.conchFactory.expectedLoseConnection = 1
|
||||
self.conchServer = reactor.listenTCP(0, self.conchFactory,
|
||||
interface="127.0.0.1")
|
||||
self.echoServer = reactor.listenTCP(0, EchoFactory())
|
||||
self.echoPort = self.echoServer.getHost().port
|
||||
self.echoServerV6 = reactor.listenTCP(0, EchoFactory(), interface="::1")
|
||||
self.echoPortV6 = self.echoServerV6.getHost().port
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
try:
|
||||
self.conchFactory.proto.done = 1
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
self.conchFactory.proto.transport.loseConnection()
|
||||
return defer.gatherResults([
|
||||
defer.maybeDeferred(self.conchServer.stopListening),
|
||||
defer.maybeDeferred(self.echoServer.stopListening),
|
||||
defer.maybeDeferred(self.echoServerV6.stopListening)])
|
||||
|
||||
|
||||
|
||||
class ForwardingMixin(ConchServerSetupMixin):
|
||||
"""
|
||||
Template class for tests of the Conch server's ability to forward arbitrary
|
||||
protocols over SSH.
|
||||
|
||||
These tests are integration tests, not unit tests. They launch a Conch
|
||||
server, a custom TCP server (just an L{EchoProtocol}) and then call
|
||||
L{execute}.
|
||||
|
||||
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
|
||||
cause an SSH client to connect to the Conch server, asking it to forward
|
||||
data to the custom TCP server.
|
||||
"""
|
||||
|
||||
def test_exec(self):
|
||||
"""
|
||||
Test that we can use whatever client to send the command "echo goodbye"
|
||||
to the Conch server. Make sure we receive "goodbye" back from the
|
||||
server.
|
||||
"""
|
||||
d = self.execute('echo goodbye', ConchTestOpenSSHProcess())
|
||||
return d.addCallback(self.assertEqual, 'goodbye\n')
|
||||
|
||||
|
||||
def test_localToRemoteForwarding(self):
|
||||
"""
|
||||
Test that we can use whatever client to forward a local port to a
|
||||
specified port on the server.
|
||||
"""
|
||||
localPort = self._getFreePort()
|
||||
process = ConchTestForwardingProcess(localPort, 'test\n')
|
||||
d = self.execute('', process,
|
||||
sshArgs='-N -L%i:127.0.0.1:%i'
|
||||
% (localPort, self.echoPort))
|
||||
d.addCallback(self.assertEqual, 'test\n')
|
||||
return d
|
||||
|
||||
|
||||
def test_remoteToLocalForwarding(self):
|
||||
"""
|
||||
Test that we can use whatever client to forward a port from the server
|
||||
to a port locally.
|
||||
"""
|
||||
localPort = self._getFreePort()
|
||||
process = ConchTestForwardingProcess(localPort, 'test\n')
|
||||
d = self.execute('', process,
|
||||
sshArgs='-N -R %i:127.0.0.1:%i'
|
||||
% (localPort, self.echoPort))
|
||||
d.addCallback(self.assertEqual, 'test\n')
|
||||
return d
|
||||
|
||||
|
||||
|
||||
# Conventionally there is a separate adapter object which provides ISession for
|
||||
# the user, but making the user provide ISession directly works too. This isn't
|
||||
# a full implementation of ISession though, just enough to make these tests
|
||||
# pass.
|
||||
@implementer(ISession)
|
||||
class RekeyAvatar(ConchUser):
|
||||
"""
|
||||
This avatar implements a shell which sends 60 numbered lines to whatever
|
||||
connects to it, then closes the session with a 0 exit status.
|
||||
|
||||
60 lines is selected as being enough to send more than 2kB of traffic, the
|
||||
amount the client is configured to initiate a rekey after.
|
||||
"""
|
||||
def __init__(self):
|
||||
ConchUser.__init__(self)
|
||||
self.channelLookup['session'] = SSHSession
|
||||
|
||||
|
||||
def openShell(self, transport):
|
||||
"""
|
||||
Write 60 lines of data to the transport, then exit.
|
||||
"""
|
||||
proto = protocol.Protocol()
|
||||
proto.makeConnection(transport)
|
||||
transport.makeConnection(wrapProtocol(proto))
|
||||
|
||||
# Send enough bytes to the connection so that a rekey is triggered in
|
||||
# the client.
|
||||
def write(counter):
|
||||
i = counter()
|
||||
if i == 60:
|
||||
call.stop()
|
||||
transport.session.conn.sendRequest(
|
||||
transport.session, 'exit-status', '\x00\x00\x00\x00')
|
||||
transport.loseConnection()
|
||||
else:
|
||||
transport.write("line #%02d\n" % (i,))
|
||||
|
||||
# The timing for this loop is an educated guess (and/or the result of
|
||||
# experimentation) to exercise the case where a packet is generated
|
||||
# mid-rekey. Since the other side of the connection is (so far) the
|
||||
# OpenSSH command line client, there's no easy way to determine when the
|
||||
# rekey has been initiated. If there were, then generating a packet
|
||||
# immediately at that time would be a better way to test the
|
||||
# functionality being tested here.
|
||||
call = LoopingCall(write, count().next)
|
||||
call.start(0.01)
|
||||
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
Ignore the close of the session.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class RekeyRealm:
|
||||
"""
|
||||
This realm gives out new L{RekeyAvatar} instances for any avatar request.
|
||||
"""
|
||||
def requestAvatar(self, avatarID, mind, *interfaces):
|
||||
return interfaces[0], RekeyAvatar(), lambda: None
|
||||
|
||||
|
||||
|
||||
class RekeyTestsMixin(ConchServerSetupMixin):
|
||||
"""
|
||||
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
|
||||
of rekeying messages.
|
||||
"""
|
||||
realmFactory = RekeyRealm
|
||||
|
||||
def test_clientRekey(self):
|
||||
"""
|
||||
After a client-initiated rekey is completed, application data continues
|
||||
to be passed over the SSH connection.
|
||||
"""
|
||||
process = ConchTestOpenSSHProcess()
|
||||
d = self.execute("", process, '-o RekeyLimit=2K')
|
||||
def finished(result):
|
||||
self.assertEqual(
|
||||
result,
|
||||
'\n'.join(['line #%02d' % (i,) for i in range(60)]) + '\n')
|
||||
d.addCallback(finished)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientMixin:
|
||||
if not which('ssh'):
|
||||
skip = "no ssh command-line client available"
|
||||
|
||||
def execute(self, remoteCommand, process, sshArgs=''):
|
||||
"""
|
||||
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
|
||||
running the 'ssh' command line tool.
|
||||
|
||||
@type remoteCommand: str
|
||||
@param remoteCommand: The command (with arguments) to run on the
|
||||
remote end.
|
||||
|
||||
@type process: L{ConchTestOpenSSHProcess}
|
||||
|
||||
@type sshArgs: str
|
||||
@param sshArgs: Arguments to pass to the 'ssh' process.
|
||||
|
||||
@return: L{defer.Deferred}
|
||||
"""
|
||||
process.deferred = defer.Deferred()
|
||||
cmdline = ('ssh -2 -l testuser -p %i '
|
||||
'-oUserKnownHostsFile=kh_test '
|
||||
'-oPasswordAuthentication=no '
|
||||
# Always use the RSA key, since that's the one in kh_test.
|
||||
'-oHostKeyAlgorithms=ssh-rsa '
|
||||
'-a '
|
||||
'-i dsa_test ') + sshArgs + \
|
||||
' 127.0.0.1 ' + remoteCommand
|
||||
port = self.conchServer.getHost().port
|
||||
cmds = (cmdline % port).split()
|
||||
reactor.spawnProcess(process, "ssh", cmds)
|
||||
return process.deferred
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientForwardingTests(ForwardingMixin, OpenSSHClientMixin,
|
||||
unittest.TestCase):
|
||||
"""
|
||||
Connection forwarding tests run against the OpenSSL command line client.
|
||||
"""
|
||||
def test_localToRemoteForwardingV6(self):
|
||||
"""
|
||||
Forwarding of arbitrary IPv6 TCP connections via SSH.
|
||||
"""
|
||||
localPort = self._getFreePort()
|
||||
process = ConchTestForwardingProcess(localPort, 'test\n')
|
||||
d = self.execute('', process,
|
||||
sshArgs='-N -L%i:[::1]:%i'
|
||||
% (localPort, self.echoPortV6))
|
||||
d.addCallback(self.assertEqual, 'test\n')
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class OpenSSHClientRekeyTests(RekeyTestsMixin, OpenSSHClientMixin,
|
||||
unittest.TestCase):
|
||||
"""
|
||||
Rekeying tests run against the OpenSSL command line client.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class CmdLineClientTests(ForwardingMixin, unittest.TestCase):
|
||||
"""
|
||||
Connection forwarding tests run against the Conch command line client.
|
||||
"""
|
||||
if runtime.platformType == 'win32':
|
||||
skip = "can't run cmdline client on win32"
|
||||
|
||||
def execute(self, remoteCommand, process, sshArgs=''):
|
||||
"""
|
||||
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
|
||||
command line tool, not 'ssh'.
|
||||
"""
|
||||
process.deferred = defer.Deferred()
|
||||
port = self.conchServer.getHost().port
|
||||
cmd = ('-p %i -l testuser '
|
||||
'--known-hosts kh_test '
|
||||
'--user-authentications publickey '
|
||||
'--host-key-algorithms ssh-rsa '
|
||||
'-a '
|
||||
'-i dsa_test '
|
||||
'-v ') % port + sshArgs + \
|
||||
' 127.0.0.1 ' + remoteCommand
|
||||
cmds = _makeArgs(cmd.split())
|
||||
log.msg(str(cmds))
|
||||
env = os.environ.copy()
|
||||
env['PYTHONPATH'] = os.pathsep.join(sys.path)
|
||||
reactor.spawnProcess(process, sys.executable, cmds, env=env)
|
||||
return process.deferred
|
@ -0,0 +1,730 @@
|
||||
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
|
||||
# See LICENSE for details
|
||||
|
||||
"""
|
||||
This module tests twisted.conch.ssh.connection.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.conch import error
|
||||
from twisted.conch.ssh import channel, common, connection
|
||||
from twisted.trial import unittest
|
||||
from twisted.conch.test import test_userauth
|
||||
|
||||
|
||||
class TestChannel(channel.SSHChannel):
|
||||
"""
|
||||
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
|
||||
|
||||
@ivar gotOpen: True if channelOpen has been called.
|
||||
@type gotOpen: C{bool}
|
||||
@ivar specificData: the specific channel open data passed to channelOpen.
|
||||
@type specificData: C{str}
|
||||
@ivar openFailureReason: the reason passed to openFailed.
|
||||
@type openFailed: C{error.ConchError}
|
||||
@ivar inBuffer: a C{list} of strings received by the channel.
|
||||
@type inBuffer: C{list}
|
||||
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
|
||||
the channel.
|
||||
@type extBuffer: C{list}
|
||||
@ivar numberRequests: the number of requests that have been made to this
|
||||
channel.
|
||||
@type numberRequests: C{int}
|
||||
@ivar gotEOF: True if the other side sent EOF.
|
||||
@type gotEOF: C{bool}
|
||||
@ivar gotOneClose: True if the other side closed the connection.
|
||||
@type gotOneClose: C{bool}
|
||||
@ivar gotClosed: True if the channel is closed.
|
||||
@type gotClosed: C{bool}
|
||||
"""
|
||||
name = "TestChannel"
|
||||
gotOpen = False
|
||||
|
||||
def logPrefix(self):
|
||||
return "TestChannel %i" % self.id
|
||||
|
||||
def channelOpen(self, specificData):
|
||||
"""
|
||||
The channel is open. Set up the instance variables.
|
||||
"""
|
||||
self.gotOpen = True
|
||||
self.specificData = specificData
|
||||
self.inBuffer = []
|
||||
self.extBuffer = []
|
||||
self.numberRequests = 0
|
||||
self.gotEOF = False
|
||||
self.gotOneClose = False
|
||||
self.gotClosed = False
|
||||
|
||||
def openFailed(self, reason):
|
||||
"""
|
||||
Opening the channel failed. Store the reason why.
|
||||
"""
|
||||
self.openFailureReason = reason
|
||||
|
||||
def request_test(self, data):
|
||||
"""
|
||||
A test request. Return True if data is 'data'.
|
||||
|
||||
@type data: C{str}
|
||||
"""
|
||||
self.numberRequests += 1
|
||||
return data == 'data'
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Data was received. Store it in the buffer.
|
||||
"""
|
||||
self.inBuffer.append(data)
|
||||
|
||||
def extReceived(self, code, data):
|
||||
"""
|
||||
Extended data was received. Store it in the buffer.
|
||||
"""
|
||||
self.extBuffer.append((code, data))
|
||||
|
||||
def eofReceived(self):
|
||||
"""
|
||||
EOF was received. Remember it.
|
||||
"""
|
||||
self.gotEOF = True
|
||||
|
||||
def closeReceived(self):
|
||||
"""
|
||||
Close was received. Remember it.
|
||||
"""
|
||||
self.gotOneClose = True
|
||||
|
||||
def closed(self):
|
||||
"""
|
||||
The channel is closed. Rembember it.
|
||||
"""
|
||||
self.gotClosed = True
|
||||
|
||||
class TestAvatar:
|
||||
"""
|
||||
A mocked-up version of twisted.conch.avatar.ConchUser
|
||||
"""
|
||||
_ARGS_ERROR_CODE = 123
|
||||
|
||||
def lookupChannel(self, channelType, windowSize, maxPacket, data):
|
||||
"""
|
||||
The server wants us to return a channel. If the requested channel is
|
||||
our TestChannel, return it, otherwise return None.
|
||||
"""
|
||||
if channelType == TestChannel.name:
|
||||
return TestChannel(remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket,
|
||||
data=data, avatar=self)
|
||||
elif channelType == "conch-error-args":
|
||||
# Raise a ConchError with backwards arguments to make sure the
|
||||
# connection fixes it for us. This case should be deprecated and
|
||||
# deleted eventually, but only after all of Conch gets the argument
|
||||
# order right.
|
||||
raise error.ConchError(
|
||||
self._ARGS_ERROR_CODE, "error args in wrong order")
|
||||
|
||||
|
||||
def gotGlobalRequest(self, requestType, data):
|
||||
"""
|
||||
The client has made a global request. If the global request is
|
||||
'TestGlobal', return True. If the global request is 'TestData',
|
||||
return True and the request-specific data we received. Otherwise,
|
||||
return False.
|
||||
"""
|
||||
if requestType == 'TestGlobal':
|
||||
return True
|
||||
elif requestType == 'TestData':
|
||||
return True, data
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
|
||||
class TestConnection(connection.SSHConnection):
|
||||
"""
|
||||
A subclass of SSHConnection for testing.
|
||||
|
||||
@ivar channel: the current channel.
|
||||
@type channel. C{TestChannel}
|
||||
"""
|
||||
|
||||
def logPrefix(self):
|
||||
return "TestConnection"
|
||||
|
||||
def global_TestGlobal(self, data):
|
||||
"""
|
||||
The other side made the 'TestGlobal' global request. Return True.
|
||||
"""
|
||||
return True
|
||||
|
||||
def global_Test_Data(self, data):
|
||||
"""
|
||||
The other side made the 'Test-Data' global request. Return True and
|
||||
the data we received.
|
||||
"""
|
||||
return True, data
|
||||
|
||||
def channel_TestChannel(self, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side is requesting the TestChannel. Create a C{TestChannel}
|
||||
instance, store it, and return it.
|
||||
"""
|
||||
self.channel = TestChannel(remoteWindow=windowSize,
|
||||
remoteMaxPacket=maxPacket, data=data)
|
||||
return self.channel
|
||||
|
||||
def channel_ErrorChannel(self, windowSize, maxPacket, data):
|
||||
"""
|
||||
The other side is requesting the ErrorChannel. Raise an exception.
|
||||
"""
|
||||
raise AssertionError('no such thing')
|
||||
|
||||
|
||||
|
||||
class ConnectionTests(unittest.TestCase):
|
||||
|
||||
if test_userauth.transport is None:
|
||||
skip = "Cannot run without both PyCrypto and pyasn1"
|
||||
|
||||
def setUp(self):
|
||||
self.transport = test_userauth.FakeTransport(None)
|
||||
self.transport.avatar = TestAvatar()
|
||||
self.conn = TestConnection()
|
||||
self.conn.transport = self.transport
|
||||
self.conn.serviceStarted()
|
||||
|
||||
def _openChannel(self, channel):
|
||||
"""
|
||||
Open the channel with the default connection.
|
||||
"""
|
||||
self.conn.openChannel(channel)
|
||||
self.transport.packets = self.transport.packets[:-1]
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
|
||||
channel.id, 255) + '\x00\x02\x00\x00\x00\x00\x80\x00')
|
||||
|
||||
def tearDown(self):
|
||||
self.conn.serviceStopped()
|
||||
|
||||
def test_linkAvatar(self):
|
||||
"""
|
||||
Test that the connection links itself to the avatar in the
|
||||
transport.
|
||||
"""
|
||||
self.assertIs(self.transport.avatar.conn, self.conn)
|
||||
|
||||
def test_serviceStopped(self):
|
||||
"""
|
||||
Test that serviceStopped() closes any open channels.
|
||||
"""
|
||||
channel1 = TestChannel()
|
||||
channel2 = TestChannel()
|
||||
self.conn.openChannel(channel1)
|
||||
self.conn.openChannel(channel2)
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00' * 4)
|
||||
self.assertTrue(channel1.gotOpen)
|
||||
self.assertFalse(channel2.gotOpen)
|
||||
self.conn.serviceStopped()
|
||||
self.assertTrue(channel1.gotClosed)
|
||||
|
||||
def test_GLOBAL_REQUEST(self):
|
||||
"""
|
||||
Test that global request packets are dispatched to the global_*
|
||||
methods and the return values are translated into success or failure
|
||||
messages.
|
||||
"""
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\xff')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_SUCCESS, '')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestData') + '\xff' +
|
||||
'test data')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_SUCCESS, 'test data')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestBad') + '\xff')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_REQUEST_FAILURE, '')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_GLOBAL_REQUEST(common.NS('TestGlobal') + '\x00')
|
||||
self.assertEqual(self.transport.packets, [])
|
||||
|
||||
def test_REQUEST_SUCCESS(self):
|
||||
"""
|
||||
Test that global request success packets cause the Deferred to be
|
||||
called back.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('request', 'data', True)
|
||||
self.conn.ssh_REQUEST_SUCCESS('data')
|
||||
def check(data):
|
||||
self.assertEqual(data, 'data')
|
||||
d.addCallback(check)
|
||||
d.addErrback(self.fail)
|
||||
return d
|
||||
|
||||
def test_REQUEST_FAILURE(self):
|
||||
"""
|
||||
Test that global request failure packets cause the Deferred to be
|
||||
erred back.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('request', 'data', True)
|
||||
self.conn.ssh_REQUEST_FAILURE('data')
|
||||
def check(f):
|
||||
self.assertEqual(f.value.data, 'data')
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_OPEN(self):
|
||||
"""
|
||||
Test that open channel packets cause a channel to be created and
|
||||
opened or a failure message to be returned.
|
||||
"""
|
||||
del self.transport.avatar
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('TestChannel') +
|
||||
'\x00\x00\x00\x01' * 4)
|
||||
self.assertTrue(self.conn.channel.gotOpen)
|
||||
self.assertEqual(self.conn.channel.conn, self.conn)
|
||||
self.assertEqual(self.conn.channel.data, '\x00\x00\x00\x01')
|
||||
self.assertEqual(self.conn.channel.specificData, '\x00\x00\x00\x01')
|
||||
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
|
||||
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
|
||||
'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
|
||||
'\x00\x00\x80\x00')])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('BadChannel') +
|
||||
'\x00\x00\x00\x02' * 4)
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
|
||||
'unknown channel') + common.NS(''))])
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_OPEN(common.NS('ErrorChannel') +
|
||||
'\x00\x00\x00\x02' * 4)
|
||||
self.flushLoggedErrors()
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
|
||||
'unknown failure') + common.NS(''))])
|
||||
|
||||
|
||||
def _lookupChannelErrorTest(self, code):
|
||||
"""
|
||||
Deliver a request for a channel open which will result in an exception
|
||||
being raised during channel lookup. Assert that an error response is
|
||||
delivered as a result.
|
||||
"""
|
||||
self.transport.avatar._ARGS_ERROR_CODE = code
|
||||
self.conn.ssh_CHANNEL_OPEN(
|
||||
common.NS('conch-error-args') + '\x00\x00\x00\x01' * 4)
|
||||
errors = self.flushLoggedErrors(error.ConchError)
|
||||
self.assertEqual(
|
||||
len(errors), 1, "Expected one error, got: %r" % (errors,))
|
||||
self.assertEqual(errors[0].value.args, (123, "error args in wrong order"))
|
||||
self.assertEqual(
|
||||
self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN_FAILURE,
|
||||
# The response includes some bytes which identifying the
|
||||
# associated request, as well as the error code (7b in hex) and
|
||||
# the error message.
|
||||
'\x00\x00\x00\x01\x00\x00\x00\x7b' + common.NS(
|
||||
'error args in wrong order') + common.NS(''))])
|
||||
|
||||
|
||||
def test_lookupChannelError(self):
|
||||
"""
|
||||
If a C{lookupChannel} implementation raises L{error.ConchError} with the
|
||||
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
|
||||
sent in response to the message.
|
||||
|
||||
This is a temporary work-around until L{error.ConchError} is given
|
||||
better attributes and all of the Conch code starts constructing
|
||||
instances of it properly. Eventually this functionality should be
|
||||
deprecated and then removed.
|
||||
"""
|
||||
self._lookupChannelErrorTest(123)
|
||||
|
||||
|
||||
def test_lookupChannelErrorLongCode(self):
|
||||
"""
|
||||
Like L{test_lookupChannelError}, but for the case where the failure code
|
||||
is represented as a C{long} instead of a C{int}.
|
||||
"""
|
||||
self._lookupChannelErrorTest(123L)
|
||||
|
||||
|
||||
def test_CHANNEL_OPEN_CONFIRMATION(self):
|
||||
"""
|
||||
Test that channel open confirmation packets cause the channel to be
|
||||
notified that it's open.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION('\x00\x00\x00\x00'*5)
|
||||
self.assertEqual(channel.remoteWindowLeft, 0)
|
||||
self.assertEqual(channel.remoteMaxPacket, 0)
|
||||
self.assertEqual(channel.specificData, '\x00\x00\x00\x00')
|
||||
self.assertEqual(self.conn.channelsToRemoteChannel[channel],
|
||||
0)
|
||||
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
|
||||
|
||||
def test_CHANNEL_OPEN_FAILURE(self):
|
||||
"""
|
||||
Test that channel open failure packets cause the channel to be
|
||||
notified that its opening failed.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_OPEN_FAILURE('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x01' + common.NS('failure!'))
|
||||
self.assertEqual(channel.openFailureReason.args, ('failure!', 1))
|
||||
self.assertEqual(self.conn.channels.get(channel), None)
|
||||
|
||||
|
||||
def test_CHANNEL_WINDOW_ADJUST(self):
|
||||
"""
|
||||
Test that channel window adjust messages add bytes to the channel
|
||||
window.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
oldWindowSize = channel.remoteWindowLeft
|
||||
self.conn.ssh_CHANNEL_WINDOW_ADJUST('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x01')
|
||||
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
|
||||
|
||||
def test_CHANNEL_DATA(self):
|
||||
"""
|
||||
Test that channel data messages are passed up to the channel, or
|
||||
cause the channel to be closed if the data is too large.
|
||||
"""
|
||||
channel = TestChannel(localWindow=6, localMaxPacket=5)
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS('data'))
|
||||
self.assertEqual(channel.inBuffer, ['data'])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x04')])
|
||||
self.transport.packets = []
|
||||
longData = 'a' * (channel.localWindowLeft + 1)
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x00' + common.NS(longData))
|
||||
self.assertEqual(channel.inBuffer, ['data'])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
bigData = 'a' * (channel.localMaxPacket + 1)
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_DATA('\x00\x00\x00\x01' + common.NS(bigData))
|
||||
self.assertEqual(channel.inBuffer, [])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_CHANNEL_EXTENDED_DATA(self):
|
||||
"""
|
||||
Test that channel extended data messages are passed up to the channel,
|
||||
or cause the channel to be closed if they're too big.
|
||||
"""
|
||||
channel = TestChannel(localWindow=6, localMaxPacket=5)
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x00' + common.NS('data'))
|
||||
self.assertEqual(channel.extBuffer, [(0, 'data')])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x04')])
|
||||
self.transport.packets = []
|
||||
longData = 'a' * (channel.localWindowLeft + 1)
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x00\x00\x00\x00'
|
||||
'\x00' + common.NS(longData))
|
||||
self.assertEqual(channel.extBuffer, [(0, 'data')])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
bigData = 'a' * (channel.localMaxPacket + 1)
|
||||
self.transport.packets = []
|
||||
self.conn.ssh_CHANNEL_EXTENDED_DATA('\x00\x00\x00\x01\x00\x00\x00'
|
||||
'\x00' + common.NS(bigData))
|
||||
self.assertEqual(channel.extBuffer, [])
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_CHANNEL_EOF(self):
|
||||
"""
|
||||
Test that channel eof messages are passed up to the channel.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_EOF('\x00\x00\x00\x00')
|
||||
self.assertTrue(channel.gotEOF)
|
||||
|
||||
def test_CHANNEL_CLOSE(self):
|
||||
"""
|
||||
Test that channel close messages are passed up to the channel. Also,
|
||||
test that channel.close() is called if both sides are closed when this
|
||||
message is received.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendClose(channel)
|
||||
self.conn.ssh_CHANNEL_CLOSE('\x00\x00\x00\x00')
|
||||
self.assertTrue(channel.gotOneClose)
|
||||
self.assertTrue(channel.gotClosed)
|
||||
|
||||
def test_CHANNEL_REQUEST_success(self):
|
||||
"""
|
||||
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS('test')
|
||||
+ '\x00')
|
||||
self.assertEqual(channel.numberRequests, 1)
|
||||
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
|
||||
'test') + '\xff' + 'data')
|
||||
def check(result):
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_SUCCESS, '\x00\x00\x00\xff')])
|
||||
d.addCallback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_failure(self):
|
||||
"""
|
||||
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.ssh_CHANNEL_REQUEST('\x00\x00\x00\x00' + common.NS(
|
||||
'test') + '\xff')
|
||||
def check(result):
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_FAILURE, '\x00\x00\x00\xff'
|
||||
)])
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_SUCCESS(self):
|
||||
"""
|
||||
Test that channel request success messages cause the Deferred to be
|
||||
called back.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', 'data', True)
|
||||
self.conn.ssh_CHANNEL_SUCCESS('\x00\x00\x00\x00')
|
||||
def check(result):
|
||||
self.assertTrue(result)
|
||||
return d
|
||||
|
||||
def test_CHANNEL_REQUEST_FAILURE(self):
|
||||
"""
|
||||
Test that channel request failure messages cause the Deferred to be
|
||||
erred back.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', '', True)
|
||||
self.conn.ssh_CHANNEL_FAILURE('\x00\x00\x00\x00')
|
||||
def check(result):
|
||||
self.assertEqual(result.value.value, 'channel request failed')
|
||||
d.addCallback(self.fail)
|
||||
d.addErrback(check)
|
||||
return d
|
||||
|
||||
def test_sendGlobalRequest(self):
|
||||
"""
|
||||
Test that global request messages are sent in the right format.
|
||||
"""
|
||||
d = self.conn.sendGlobalRequest('wantReply', 'data', True)
|
||||
# must be added to prevent errbacking during teardown
|
||||
d.addErrback(lambda failure: None)
|
||||
self.conn.sendGlobalRequest('noReply', '', False)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_GLOBAL_REQUEST, common.NS('wantReply') +
|
||||
'\xffdata'),
|
||||
(connection.MSG_GLOBAL_REQUEST, common.NS('noReply') +
|
||||
'\x00')])
|
||||
self.assertEqual(self.conn.deferreds, {'global':[d]})
|
||||
|
||||
def test_openChannel(self):
|
||||
"""
|
||||
Test that open channel messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self.conn.openChannel(channel, 'aaaa')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_OPEN, common.NS('TestChannel') +
|
||||
'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
|
||||
self.assertEqual(channel.id, 0)
|
||||
self.assertEqual(self.conn.localChannelID, 1)
|
||||
|
||||
def test_sendRequest(self):
|
||||
"""
|
||||
Test that channel request messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
d = self.conn.sendRequest(channel, 'test', 'test', True)
|
||||
# needed to prevent errbacks during teardown.
|
||||
d.addErrback(lambda failure: None)
|
||||
self.conn.sendRequest(channel, 'test2', '', False)
|
||||
channel.localClosed = True # emulate sending a close message
|
||||
self.conn.sendRequest(channel, 'test3', '', True)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
|
||||
common.NS('test') + '\x01test'),
|
||||
(connection.MSG_CHANNEL_REQUEST, '\x00\x00\x00\xff' +
|
||||
common.NS('test2') + '\x00')])
|
||||
self.assertEqual(self.conn.deferreds[0], [d])
|
||||
|
||||
def test_adjustWindow(self):
|
||||
"""
|
||||
Test that channel window adjust messages cause bytes to be added
|
||||
to the window.
|
||||
"""
|
||||
channel = TestChannel(localWindow=5)
|
||||
self._openChannel(channel)
|
||||
channel.localWindowLeft = 0
|
||||
self.conn.adjustWindow(channel, 1)
|
||||
self.assertEqual(channel.localWindowLeft, 1)
|
||||
channel.localClosed = True
|
||||
self.conn.adjustWindow(channel, 2)
|
||||
self.assertEqual(channel.localWindowLeft, 1)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_WINDOW_ADJUST, '\x00\x00\x00\xff'
|
||||
'\x00\x00\x00\x01')])
|
||||
|
||||
def test_sendData(self):
|
||||
"""
|
||||
Test that channel data messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendData(channel, 'a')
|
||||
channel.localClosed = True
|
||||
self.conn.sendData(channel, 'b')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_DATA, '\x00\x00\x00\xff' +
|
||||
common.NS('a'))])
|
||||
|
||||
def test_sendExtendedData(self):
|
||||
"""
|
||||
Test that channel extended data messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendExtendedData(channel, 1, 'test')
|
||||
channel.localClosed = True
|
||||
self.conn.sendExtendedData(channel, 2, 'test2')
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EXTENDED_DATA, '\x00\x00\x00\xff' +
|
||||
'\x00\x00\x00\x01' + common.NS('test'))])
|
||||
|
||||
def test_sendEOF(self):
|
||||
"""
|
||||
Test that channel EOF messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendEOF(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
|
||||
channel.localClosed = True
|
||||
self.conn.sendEOF(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_EOF, '\x00\x00\x00\xff')])
|
||||
|
||||
def test_sendClose(self):
|
||||
"""
|
||||
Test that channel close messages are sent in the right format.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
self.conn.sendClose(channel)
|
||||
self.assertTrue(channel.localClosed)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
self.conn.sendClose(channel)
|
||||
self.assertEqual(self.transport.packets,
|
||||
[(connection.MSG_CHANNEL_CLOSE, '\x00\x00\x00\xff')])
|
||||
|
||||
channel2 = TestChannel()
|
||||
self._openChannel(channel2)
|
||||
channel2.remoteClosed = True
|
||||
self.conn.sendClose(channel2)
|
||||
self.assertTrue(channel2.gotClosed)
|
||||
|
||||
def test_getChannelWithAvatar(self):
|
||||
"""
|
||||
Test that getChannel dispatches to the avatar when an avatar is
|
||||
present. Correct functioning without the avatar is verified in
|
||||
test_CHANNEL_OPEN.
|
||||
"""
|
||||
channel = self.conn.getChannel('TestChannel', 50, 30, 'data')
|
||||
self.assertEqual(channel.data, 'data')
|
||||
self.assertEqual(channel.remoteWindowLeft, 50)
|
||||
self.assertEqual(channel.remoteMaxPacket, 30)
|
||||
self.assertRaises(error.ConchError, self.conn.getChannel,
|
||||
'BadChannel', 50, 30, 'data')
|
||||
|
||||
def test_gotGlobalRequestWithoutAvatar(self):
|
||||
"""
|
||||
Test that gotGlobalRequests dispatches to global_* without an avatar.
|
||||
"""
|
||||
del self.transport.avatar
|
||||
self.assertTrue(self.conn.gotGlobalRequest('TestGlobal', 'data'))
|
||||
self.assertEqual(self.conn.gotGlobalRequest('Test-Data', 'data'),
|
||||
(True, 'data'))
|
||||
self.assertFalse(self.conn.gotGlobalRequest('BadGlobal', 'data'))
|
||||
|
||||
|
||||
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
|
||||
"""
|
||||
Whenever an SSH channel gets closed any Deferred that was returned by a
|
||||
sendRequest() on its parent connection must be errbacked.
|
||||
"""
|
||||
channel = TestChannel()
|
||||
self._openChannel(channel)
|
||||
|
||||
d = self.conn.sendRequest(
|
||||
channel, "dummyrequest", "dummydata", wantReply=1)
|
||||
d = self.assertFailure(d, error.ConchError)
|
||||
self.conn.channelClosed(channel)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
class CleanConnectionShutdownTests(unittest.TestCase):
|
||||
"""
|
||||
Check whether correct cleanup is performed on connection shutdown.
|
||||
"""
|
||||
if test_userauth.transport is None:
|
||||
skip = "Cannot run without both PyCrypto and pyasn1"
|
||||
|
||||
def setUp(self):
|
||||
self.transport = test_userauth.FakeTransport(None)
|
||||
self.transport.avatar = TestAvatar()
|
||||
self.conn = TestConnection()
|
||||
self.conn.transport = self.transport
|
||||
|
||||
|
||||
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
|
||||
"""
|
||||
Once the service is stopped any leftover global deferred returned by
|
||||
a sendGlobalRequest() call must be errbacked.
|
||||
"""
|
||||
self.conn.serviceStarted()
|
||||
|
||||
d = self.conn.sendGlobalRequest(
|
||||
"dummyrequest", "dummydata", wantReply=1)
|
||||
d = self.assertFailure(d, error.ConchError)
|
||||
self.conn.serviceStopped()
|
||||
return d
|
||||
|
||||
|
@ -0,0 +1,169 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.client.default}.
|
||||
"""
|
||||
from twisted.python.reflect import requireModule
|
||||
|
||||
if requireModule('Crypto.Cipher.DES3') and requireModule('pyasn1'):
|
||||
from twisted.conch.client.agent import SSHAgentClient
|
||||
from twisted.conch.client.default import SSHUserAuthClient
|
||||
from twisted.conch.client.options import ConchOptions
|
||||
from twisted.conch.ssh.keys import Key
|
||||
else:
|
||||
skip = "PyCrypto and PyASN1 required for twisted.conch.client.default."
|
||||
|
||||
from twisted.trial.unittest import TestCase
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.conch.test import keydata
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
|
||||
|
||||
class SSHUserAuthClientTests(TestCase):
|
||||
"""
|
||||
Tests for L{SSHUserAuthClient}.
|
||||
|
||||
@type rsaPublic: L{Key}
|
||||
@ivar rsaPublic: A public RSA key.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.rsaPublic = Key.fromString(keydata.publicRSA_openssh)
|
||||
self.tmpdir = FilePath(self.mktemp())
|
||||
self.tmpdir.makedirs()
|
||||
self.rsaFile = self.tmpdir.child('id_rsa')
|
||||
self.rsaFile.setContent(keydata.privateRSA_openssh)
|
||||
self.tmpdir.child('id_rsa.pub').setContent(keydata.publicRSA_openssh)
|
||||
|
||||
|
||||
def test_signDataWithAgent(self):
|
||||
"""
|
||||
When connected to an agent, L{SSHUserAuthClient} can use it to
|
||||
request signatures of particular data with a particular L{Key}.
|
||||
"""
|
||||
client = SSHUserAuthClient("user", ConchOptions(), None)
|
||||
agent = SSHAgentClient()
|
||||
transport = StringTransport()
|
||||
agent.makeConnection(transport)
|
||||
client.keyAgent = agent
|
||||
cleartext = "Sign here"
|
||||
client.signData(self.rsaPublic, cleartext)
|
||||
self.assertEqual(
|
||||
transport.value(),
|
||||
"\x00\x00\x00\x8b\r\x00\x00\x00u" + self.rsaPublic.blob() +
|
||||
"\x00\x00\x00\t" + cleartext +
|
||||
"\x00\x00\x00\x00")
|
||||
|
||||
|
||||
def test_agentGetPublicKey(self):
|
||||
"""
|
||||
L{SSHUserAuthClient} looks up public keys from the agent using the
|
||||
L{SSHAgentClient} class. That L{SSHAgentClient.getPublicKey} returns a
|
||||
L{Key} object with one of the public keys in the agent. If no more
|
||||
keys are present, it returns C{None}.
|
||||
"""
|
||||
agent = SSHAgentClient()
|
||||
agent.blobs = [self.rsaPublic.blob()]
|
||||
key = agent.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
self.assertEqual(agent.getPublicKey(), None)
|
||||
|
||||
|
||||
def test_getPublicKeyFromFile(self):
|
||||
"""
|
||||
L{SSHUserAuthClient.getPublicKey()} is able to get a public key from
|
||||
the first file described by its options' C{identitys} list, and return
|
||||
the corresponding public L{Key} object.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
|
||||
|
||||
def test_getPublicKeyAgentFallback(self):
|
||||
"""
|
||||
If an agent is present, but doesn't return a key,
|
||||
L{SSHUserAuthClient.getPublicKey} continue with the normal key lookup.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
agent = SSHAgentClient()
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
client.keyAgent = agent
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, self.rsaPublic)
|
||||
|
||||
|
||||
def test_getPublicKeyBadKeyError(self):
|
||||
"""
|
||||
If L{keys.Key.fromFile} raises a L{keys.BadKeyError}, the
|
||||
L{SSHUserAuthClient.getPublicKey} tries again to get a public key by
|
||||
calling itself recursively.
|
||||
"""
|
||||
options = ConchOptions()
|
||||
self.tmpdir.child('id_dsa.pub').setContent(keydata.publicDSA_openssh)
|
||||
dsaFile = self.tmpdir.child('id_dsa')
|
||||
dsaFile.setContent(keydata.privateDSA_openssh)
|
||||
options.identitys = [self.rsaFile.path, dsaFile.path]
|
||||
self.tmpdir.child('id_rsa.pub').setContent('not a key!')
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
key = client.getPublicKey()
|
||||
self.assertEqual(key.isPublic(), True)
|
||||
self.assertEqual(key, Key.fromString(keydata.publicDSA_openssh))
|
||||
self.assertEqual(client.usedFiles, [self.rsaFile.path, dsaFile.path])
|
||||
|
||||
|
||||
def test_getPrivateKey(self):
|
||||
"""
|
||||
L{SSHUserAuthClient.getPrivateKey} will load a private key from the
|
||||
last used file populated by L{SSHUserAuthClient.getPublicKey}, and
|
||||
return a L{Deferred} which fires with the corresponding private L{Key}.
|
||||
"""
|
||||
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
# Populate the list of used files
|
||||
client.getPublicKey()
|
||||
|
||||
def _cbGetPrivateKey(key):
|
||||
self.assertEqual(key.isPublic(), False)
|
||||
self.assertEqual(key, rsaPrivate)
|
||||
|
||||
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
|
||||
|
||||
|
||||
def test_getPrivateKeyPassphrase(self):
|
||||
"""
|
||||
L{SSHUserAuthClient} can get a private key from a file, and return a
|
||||
Deferred called back with a private L{Key} object, even if the key is
|
||||
encrypted.
|
||||
"""
|
||||
rsaPrivate = Key.fromString(keydata.privateRSA_openssh)
|
||||
passphrase = 'this is the passphrase'
|
||||
self.rsaFile.setContent(rsaPrivate.toString('openssh', passphrase))
|
||||
options = ConchOptions()
|
||||
options.identitys = [self.rsaFile.path]
|
||||
client = SSHUserAuthClient("user", options, None)
|
||||
# Populate the list of used files
|
||||
client.getPublicKey()
|
||||
|
||||
def _getPassword(prompt):
|
||||
self.assertEqual(prompt,
|
||||
"Enter passphrase for key '%s': " % (
|
||||
self.rsaFile.path,))
|
||||
return passphrase
|
||||
|
||||
def _cbGetPrivateKey(key):
|
||||
self.assertEqual(key.isPublic(), False)
|
||||
self.assertEqual(key, rsaPrivate)
|
||||
|
||||
self.patch(client, '_getPassword', _getPassword)
|
||||
return client.getPrivateKey().addCallback(_cbGetPrivateKey)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,770 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_filetransfer -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE file for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.filetransfer}.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
|
||||
from twisted.trial import unittest
|
||||
try:
|
||||
from twisted.conch import unix
|
||||
unix # shut up pyflakes
|
||||
except ImportError:
|
||||
unix = None
|
||||
|
||||
from twisted.conch import avatar
|
||||
from twisted.conch.ssh import common, connection, filetransfer, session
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols import loopback
|
||||
from twisted.python import components
|
||||
|
||||
|
||||
class TestAvatar(avatar.ConchUser):
|
||||
def __init__(self):
|
||||
avatar.ConchUser.__init__(self)
|
||||
self.channelLookup['session'] = session.SSHSession
|
||||
self.subsystemLookup['sftp'] = filetransfer.FileTransferServer
|
||||
|
||||
def _runAsUser(self, f, *args, **kw):
|
||||
try:
|
||||
f = iter(f)
|
||||
except TypeError:
|
||||
f = [(f, args, kw)]
|
||||
for i in f:
|
||||
func = i[0]
|
||||
args = len(i)>1 and i[1] or ()
|
||||
kw = len(i)>2 and i[2] or {}
|
||||
r = func(*args, **kw)
|
||||
return r
|
||||
|
||||
|
||||
class FileTransferTestAvatar(TestAvatar):
|
||||
|
||||
def __init__(self, homeDir):
|
||||
TestAvatar.__init__(self)
|
||||
self.homeDir = homeDir
|
||||
|
||||
def getHomeDir(self):
|
||||
return os.path.join(os.getcwd(), self.homeDir)
|
||||
|
||||
|
||||
class ConchSessionForTestAvatar:
|
||||
|
||||
def __init__(self, avatar):
|
||||
self.avatar = avatar
|
||||
|
||||
if unix:
|
||||
if not hasattr(unix, 'SFTPServerForUnixConchUser'):
|
||||
# unix should either be a fully working module, or None. I'm not sure
|
||||
# how this happens, but on win32 it does. Try to cope. --spiv.
|
||||
import warnings
|
||||
warnings.warn(("twisted.conch.unix imported %r, "
|
||||
"but doesn't define SFTPServerForUnixConchUser'")
|
||||
% (unix,))
|
||||
unix = None
|
||||
else:
|
||||
class FileTransferForTestAvatar(unix.SFTPServerForUnixConchUser):
|
||||
|
||||
def gotVersion(self, version, otherExt):
|
||||
return {'conchTest' : 'ext data'}
|
||||
|
||||
def extendedRequest(self, extName, extData):
|
||||
if extName == 'testExtendedRequest':
|
||||
return 'bar'
|
||||
raise NotImplementedError
|
||||
|
||||
components.registerAdapter(FileTransferForTestAvatar,
|
||||
TestAvatar,
|
||||
filetransfer.ISFTPServer)
|
||||
|
||||
class SFTPTestBase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.testDir = self.mktemp()
|
||||
# Give the testDir another level so we can safely "cd .." from it in
|
||||
# tests.
|
||||
self.testDir = os.path.join(self.testDir, 'extra')
|
||||
os.makedirs(os.path.join(self.testDir, 'testDirectory'))
|
||||
|
||||
f = file(os.path.join(self.testDir, 'testfile1'),'w')
|
||||
f.write('a'*10+'b'*10)
|
||||
f.write(file('/dev/urandom').read(1024*64)) # random data
|
||||
os.chmod(os.path.join(self.testDir, 'testfile1'), 0644)
|
||||
file(os.path.join(self.testDir, 'testRemoveFile'), 'w').write('a')
|
||||
file(os.path.join(self.testDir, 'testRenameFile'), 'w').write('a')
|
||||
file(os.path.join(self.testDir, '.testHiddenFile'), 'w').write('a')
|
||||
|
||||
|
||||
class OurServerOurClientTests(SFTPTestBase):
|
||||
|
||||
if not unix:
|
||||
skip = "can't run on non-posix computers"
|
||||
|
||||
def setUp(self):
|
||||
SFTPTestBase.setUp(self)
|
||||
|
||||
self.avatar = FileTransferTestAvatar(self.testDir)
|
||||
self.server = filetransfer.FileTransferServer(avatar=self.avatar)
|
||||
clientTransport = loopback.LoopbackRelay(self.server)
|
||||
|
||||
self.client = filetransfer.FileTransferClient()
|
||||
self._serverVersion = None
|
||||
self._extData = None
|
||||
def _(serverVersion, extData):
|
||||
self._serverVersion = serverVersion
|
||||
self._extData = extData
|
||||
self.client.gotServerVersion = _
|
||||
serverTransport = loopback.LoopbackRelay(self.client)
|
||||
self.client.makeConnection(clientTransport)
|
||||
self.server.makeConnection(serverTransport)
|
||||
|
||||
self.clientTransport = clientTransport
|
||||
self.serverTransport = serverTransport
|
||||
|
||||
self._emptyBuffers()
|
||||
|
||||
|
||||
def _emptyBuffers(self):
|
||||
while self.serverTransport.buffer or self.clientTransport.buffer:
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
|
||||
|
||||
def testServerVersion(self):
|
||||
self.assertEqual(self._serverVersion, 3)
|
||||
self.assertEqual(self._extData, {'conchTest' : 'ext data'})
|
||||
|
||||
|
||||
def test_interface_implementation(self):
|
||||
"""
|
||||
It implements the ISFTPServer interface.
|
||||
"""
|
||||
self.assertTrue(
|
||||
filetransfer.ISFTPServer.providedBy(self.server.client),
|
||||
"ISFTPServer not provided by %r" % (self.server.client,))
|
||||
|
||||
|
||||
def test_openedFileClosedWithConnection(self):
|
||||
"""
|
||||
A file opened with C{openFile} is close when the connection is lost.
|
||||
"""
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
oldClose = os.close
|
||||
closed = []
|
||||
def close(fd):
|
||||
closed.append(fd)
|
||||
oldClose(fd)
|
||||
|
||||
self.patch(os, "close", close)
|
||||
|
||||
def _fileOpened(openFile):
|
||||
fd = self.server.openFiles[openFile.handle[4:]].fd
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
self.assertEqual(self.server.openFiles, {})
|
||||
self.assertIn(fd, closed)
|
||||
|
||||
d.addCallback(_fileOpened)
|
||||
return d
|
||||
|
||||
|
||||
def test_openedDirectoryClosedWithConnection(self):
|
||||
"""
|
||||
A directory opened with C{openDirectory} is close when the connection
|
||||
is lost.
|
||||
"""
|
||||
d = self.client.openDirectory('')
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getFiles(openDir):
|
||||
self.serverTransport.loseConnection()
|
||||
self.clientTransport.loseConnection()
|
||||
self.serverTransport.clearBuffer()
|
||||
self.clientTransport.clearBuffer()
|
||||
self.assertEqual(self.server.openDirs, {})
|
||||
|
||||
d.addCallback(_getFiles)
|
||||
return d
|
||||
|
||||
|
||||
def testOpenFileIO(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _fileOpened(openFile):
|
||||
self.assertEqual(openFile, filetransfer.ISFTPFile(openFile))
|
||||
d = _readChunk(openFile)
|
||||
d.addCallback(_writeChunk, openFile)
|
||||
return d
|
||||
|
||||
def _readChunk(openFile):
|
||||
d = openFile.readChunk(0, 20)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10)
|
||||
return d
|
||||
|
||||
def _writeChunk(_, openFile):
|
||||
d = openFile.writeChunk(20, 'c'*10)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_readChunk2, openFile)
|
||||
return d
|
||||
|
||||
def _readChunk2(_, openFile):
|
||||
d = openFile.readChunk(0, 30)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'a'*10 + 'b'*10 + 'c'*10)
|
||||
return d
|
||||
|
||||
d.addCallback(_fileOpened)
|
||||
return d
|
||||
|
||||
def testClosedFileGetAttrs(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(_, openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
def _err(f):
|
||||
self.flushLoggedErrors()
|
||||
return f
|
||||
|
||||
def _close(openFile):
|
||||
d = openFile.close()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs, openFile)
|
||||
d.addErrback(_err)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
d.addCallback(_close)
|
||||
return d
|
||||
|
||||
def testOpenFileAttributes(self):
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs2)
|
||||
return d
|
||||
|
||||
def _getAttrs2(attrs1):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, attrs1)
|
||||
return d
|
||||
|
||||
return d.addCallback(_getAttrs)
|
||||
|
||||
|
||||
def testOpenFileSetAttrs(self):
|
||||
# XXX test setAttrs
|
||||
# Ok, how about this for a start? It caught a bug :) -- spiv.
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(openFile):
|
||||
d = openFile.getAttrs()
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_setAttrs)
|
||||
return d
|
||||
|
||||
def _setAttrs(attrs):
|
||||
attrs['atime'] = 0
|
||||
d = self.client.setAttrs('testfile1', attrs)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_getAttrs2)
|
||||
d.addCallback(self.assertEqual, attrs)
|
||||
return d
|
||||
|
||||
def _getAttrs2(_):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getAttrs)
|
||||
return d
|
||||
|
||||
|
||||
def test_openFileExtendedAttributes(self):
|
||||
"""
|
||||
Check that L{filetransfer.FileTransferClient.openFile} can send
|
||||
extended attributes, that should be extracted server side. By default,
|
||||
they are ignored, so we just verify they are correctly parsed.
|
||||
"""
|
||||
savedAttributes = {}
|
||||
oldOpenFile = self.server.client.openFile
|
||||
def openFile(filename, flags, attrs):
|
||||
savedAttributes.update(attrs)
|
||||
return oldOpenFile(filename, flags, attrs)
|
||||
self.server.client.openFile = openFile
|
||||
|
||||
d = self.client.openFile("testfile1", filetransfer.FXF_READ |
|
||||
filetransfer.FXF_WRITE, {"ext_foo": "bar"})
|
||||
self._emptyBuffers()
|
||||
|
||||
def check(ign):
|
||||
self.assertEqual(savedAttributes, {"ext_foo": "bar"})
|
||||
|
||||
return d.addCallback(check)
|
||||
|
||||
|
||||
def testRemoveFile(self):
|
||||
d = self.client.getAttrs("testRemoveFile")
|
||||
self._emptyBuffers()
|
||||
def _removeFile(ignored):
|
||||
d = self.client.removeFile("testRemoveFile")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
d.addCallback(_removeFile)
|
||||
d.addCallback(_removeFile)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testRenameFile(self):
|
||||
d = self.client.getAttrs("testRenameFile")
|
||||
self._emptyBuffers()
|
||||
def _rename(attrs):
|
||||
d = self.client.renameFile("testRenameFile", "testRenamedFile")
|
||||
self._emptyBuffers()
|
||||
d.addCallback(_testRenamed, attrs)
|
||||
return d
|
||||
def _testRenamed(_, attrs):
|
||||
d = self.client.getAttrs("testRenamedFile")
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, attrs)
|
||||
return d.addCallback(_rename)
|
||||
|
||||
def testDirectoryBad(self):
|
||||
d = self.client.getAttrs("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testDirectoryCreation(self):
|
||||
d = self.client.makeDirectory("testMakeDirectory", {})
|
||||
self._emptyBuffers()
|
||||
|
||||
def _getAttrs(_):
|
||||
d = self.client.getAttrs("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
# XXX not until version 4/5
|
||||
# self.assertEqual(filetransfer.FILEXFER_TYPE_DIRECTORY&attrs['type'],
|
||||
# filetransfer.FILEXFER_TYPE_DIRECTORY)
|
||||
|
||||
def _removeDirectory(_):
|
||||
d = self.client.removeDirectory("testMakeDirectory")
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getAttrs)
|
||||
d.addCallback(_removeDirectory)
|
||||
d.addCallback(_getAttrs)
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testOpenDirectory(self):
|
||||
d = self.client.openDirectory('')
|
||||
self._emptyBuffers()
|
||||
files = []
|
||||
|
||||
def _getFiles(openDir):
|
||||
def append(f):
|
||||
files.append(f)
|
||||
return openDir
|
||||
d = defer.maybeDeferred(openDir.next)
|
||||
self._emptyBuffers()
|
||||
d.addCallback(append)
|
||||
d.addCallback(_getFiles)
|
||||
d.addErrback(_close, openDir)
|
||||
return d
|
||||
|
||||
def _checkFiles(ignored):
|
||||
fs = list(zip(*files)[0])
|
||||
fs.sort()
|
||||
self.assertEqual(fs,
|
||||
['.testHiddenFile', 'testDirectory',
|
||||
'testRemoveFile', 'testRenameFile',
|
||||
'testfile1'])
|
||||
|
||||
def _close(_, openDir):
|
||||
d = openDir.close()
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
|
||||
d.addCallback(_getFiles)
|
||||
d.addCallback(_checkFiles)
|
||||
return d
|
||||
|
||||
def testLinkDoesntExist(self):
|
||||
d = self.client.getAttrs('testLink')
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, filetransfer.SFTPError)
|
||||
|
||||
def testLinkSharesAttrs(self):
|
||||
d = self.client.makeLink('testLink', 'testfile1')
|
||||
self._emptyBuffers()
|
||||
def _getFirstAttrs(_):
|
||||
d = self.client.getAttrs('testLink', 1)
|
||||
self._emptyBuffers()
|
||||
return d
|
||||
def _getSecondAttrs(firstAttrs):
|
||||
d = self.client.getAttrs('testfile1')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, firstAttrs)
|
||||
return d
|
||||
d.addCallback(_getFirstAttrs)
|
||||
return d.addCallback(_getSecondAttrs)
|
||||
|
||||
def testLinkPath(self):
|
||||
d = self.client.makeLink('testLink', 'testfile1')
|
||||
self._emptyBuffers()
|
||||
def _readLink(_):
|
||||
d = self.client.readLink('testLink')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual,
|
||||
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
|
||||
return d
|
||||
def _realPath(_):
|
||||
d = self.client.realPath('testLink')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual,
|
||||
os.path.join(os.getcwd(), self.testDir, 'testfile1'))
|
||||
return d
|
||||
d.addCallback(_readLink)
|
||||
d.addCallback(_realPath)
|
||||
return d
|
||||
|
||||
def testExtendedRequest(self):
|
||||
d = self.client.extendedRequest('testExtendedRequest', 'foo')
|
||||
self._emptyBuffers()
|
||||
d.addCallback(self.assertEqual, 'bar')
|
||||
d.addCallback(self._cbTestExtendedRequest)
|
||||
return d
|
||||
|
||||
def _cbTestExtendedRequest(self, ignored):
|
||||
d = self.client.extendedRequest('testBadRequest', '')
|
||||
self._emptyBuffers()
|
||||
return self.assertFailure(d, NotImplementedError)
|
||||
|
||||
|
||||
class FakeConn:
|
||||
def sendClose(self, channel):
|
||||
pass
|
||||
|
||||
|
||||
class FileTransferCloseTests(unittest.TestCase):
|
||||
|
||||
if not unix:
|
||||
skip = "can't run on non-posix computers"
|
||||
|
||||
def setUp(self):
|
||||
self.avatar = TestAvatar()
|
||||
|
||||
def buildServerConnection(self):
|
||||
# make a server connection
|
||||
conn = connection.SSHConnection()
|
||||
# server connections have a 'self.transport.avatar'.
|
||||
class DummyTransport:
|
||||
def __init__(self):
|
||||
self.transport = self
|
||||
def sendPacket(self, kind, data):
|
||||
pass
|
||||
def logPrefix(self):
|
||||
return 'dummy transport'
|
||||
conn.transport = DummyTransport()
|
||||
conn.transport.avatar = self.avatar
|
||||
return conn
|
||||
|
||||
def interceptConnectionLost(self, sftpServer):
|
||||
self.connectionLostFired = False
|
||||
origConnectionLost = sftpServer.connectionLost
|
||||
def connectionLost(reason):
|
||||
self.connectionLostFired = True
|
||||
origConnectionLost(reason)
|
||||
sftpServer.connectionLost = connectionLost
|
||||
|
||||
def assertSFTPConnectionLost(self):
|
||||
self.assertTrue(self.connectionLostFired,
|
||||
"sftpServer's connectionLost was not called")
|
||||
|
||||
def test_sessionClose(self):
|
||||
"""
|
||||
Closing a session should notify an SFTP subsystem launched by that
|
||||
session.
|
||||
"""
|
||||
# make a session
|
||||
testSession = session.SSHSession(conn=FakeConn(), avatar=self.avatar)
|
||||
|
||||
# start an SFTP subsystem on the session
|
||||
testSession.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = testSession.client.transport.proto
|
||||
|
||||
# intercept connectionLost so we can check that it's called
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close session
|
||||
testSession.closeReceived()
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
def test_clientClosesChannelOnConnnection(self):
|
||||
"""
|
||||
A client sending CHANNEL_CLOSE should trigger closeReceived on the
|
||||
associated channel instance.
|
||||
"""
|
||||
conn = self.buildServerConnection()
|
||||
|
||||
# somehow get a session
|
||||
packet = common.NS('session') + struct.pack('>L', 0) * 3
|
||||
conn.ssh_CHANNEL_OPEN(packet)
|
||||
sessionChannel = conn.channels[0]
|
||||
|
||||
sessionChannel.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = sessionChannel.client.transport.proto
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# intercept closeReceived
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close the connection
|
||||
conn.ssh_CHANNEL_CLOSE(struct.pack('>L', 0))
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
|
||||
def test_stopConnectionServiceClosesChannel(self):
|
||||
"""
|
||||
Closing an SSH connection should close all sessions within it.
|
||||
"""
|
||||
conn = self.buildServerConnection()
|
||||
|
||||
# somehow get a session
|
||||
packet = common.NS('session') + struct.pack('>L', 0) * 3
|
||||
conn.ssh_CHANNEL_OPEN(packet)
|
||||
sessionChannel = conn.channels[0]
|
||||
|
||||
sessionChannel.request_subsystem(common.NS('sftp'))
|
||||
sftpServer = sessionChannel.client.transport.proto
|
||||
self.interceptConnectionLost(sftpServer)
|
||||
|
||||
# close the connection
|
||||
conn.serviceStopped()
|
||||
|
||||
self.assertSFTPConnectionLost()
|
||||
|
||||
|
||||
|
||||
class ConstantsTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the constants used by the SFTP protocol implementation.
|
||||
|
||||
@ivar filexferSpecExcerpts: Excerpts from the
|
||||
draft-ietf-secsh-filexfer-02.txt (draft) specification of the SFTP
|
||||
protocol. There are more recent drafts of the specification, but this
|
||||
one describes version 3, which is what conch (and OpenSSH) implements.
|
||||
"""
|
||||
|
||||
|
||||
filexferSpecExcerpts = [
|
||||
"""
|
||||
The following values are defined for packet types.
|
||||
|
||||
#define SSH_FXP_INIT 1
|
||||
#define SSH_FXP_VERSION 2
|
||||
#define SSH_FXP_OPEN 3
|
||||
#define SSH_FXP_CLOSE 4
|
||||
#define SSH_FXP_READ 5
|
||||
#define SSH_FXP_WRITE 6
|
||||
#define SSH_FXP_LSTAT 7
|
||||
#define SSH_FXP_FSTAT 8
|
||||
#define SSH_FXP_SETSTAT 9
|
||||
#define SSH_FXP_FSETSTAT 10
|
||||
#define SSH_FXP_OPENDIR 11
|
||||
#define SSH_FXP_READDIR 12
|
||||
#define SSH_FXP_REMOVE 13
|
||||
#define SSH_FXP_MKDIR 14
|
||||
#define SSH_FXP_RMDIR 15
|
||||
#define SSH_FXP_REALPATH 16
|
||||
#define SSH_FXP_STAT 17
|
||||
#define SSH_FXP_RENAME 18
|
||||
#define SSH_FXP_READLINK 19
|
||||
#define SSH_FXP_SYMLINK 20
|
||||
#define SSH_FXP_STATUS 101
|
||||
#define SSH_FXP_HANDLE 102
|
||||
#define SSH_FXP_DATA 103
|
||||
#define SSH_FXP_NAME 104
|
||||
#define SSH_FXP_ATTRS 105
|
||||
#define SSH_FXP_EXTENDED 200
|
||||
#define SSH_FXP_EXTENDED_REPLY 201
|
||||
|
||||
Additional packet types should only be defined if the protocol
|
||||
version number (see Section ``Protocol Initialization'') is
|
||||
incremented, and their use MUST be negotiated using the version
|
||||
number. However, the SSH_FXP_EXTENDED and SSH_FXP_EXTENDED_REPLY
|
||||
packets can be used to implement vendor-specific extensions. See
|
||||
Section ``Vendor-Specific-Extensions'' for more details.
|
||||
""",
|
||||
"""
|
||||
The flags bits are defined to have the following values:
|
||||
|
||||
#define SSH_FILEXFER_ATTR_SIZE 0x00000001
|
||||
#define SSH_FILEXFER_ATTR_UIDGID 0x00000002
|
||||
#define SSH_FILEXFER_ATTR_PERMISSIONS 0x00000004
|
||||
#define SSH_FILEXFER_ATTR_ACMODTIME 0x00000008
|
||||
#define SSH_FILEXFER_ATTR_EXTENDED 0x80000000
|
||||
|
||||
""",
|
||||
"""
|
||||
The `pflags' field is a bitmask. The following bits have been
|
||||
defined.
|
||||
|
||||
#define SSH_FXF_READ 0x00000001
|
||||
#define SSH_FXF_WRITE 0x00000002
|
||||
#define SSH_FXF_APPEND 0x00000004
|
||||
#define SSH_FXF_CREAT 0x00000008
|
||||
#define SSH_FXF_TRUNC 0x00000010
|
||||
#define SSH_FXF_EXCL 0x00000020
|
||||
""",
|
||||
"""
|
||||
Currently, the following values are defined (other values may be
|
||||
defined by future versions of this protocol):
|
||||
|
||||
#define SSH_FX_OK 0
|
||||
#define SSH_FX_EOF 1
|
||||
#define SSH_FX_NO_SUCH_FILE 2
|
||||
#define SSH_FX_PERMISSION_DENIED 3
|
||||
#define SSH_FX_FAILURE 4
|
||||
#define SSH_FX_BAD_MESSAGE 5
|
||||
#define SSH_FX_NO_CONNECTION 6
|
||||
#define SSH_FX_CONNECTION_LOST 7
|
||||
#define SSH_FX_OP_UNSUPPORTED 8
|
||||
"""]
|
||||
|
||||
|
||||
def test_constantsAgainstSpec(self):
|
||||
"""
|
||||
The constants used by the SFTP protocol implementation match those
|
||||
found by searching through the spec.
|
||||
"""
|
||||
constants = {}
|
||||
for excerpt in self.filexferSpecExcerpts:
|
||||
for line in excerpt.splitlines():
|
||||
m = re.match('^\s*#define SSH_([A-Z_]+)\s+([0-9x]*)\s*$', line)
|
||||
if m:
|
||||
constants[m.group(1)] = long(m.group(2), 0)
|
||||
self.assertTrue(
|
||||
len(constants) > 0, "No constants found (the test must be buggy).")
|
||||
for k, v in constants.items():
|
||||
self.assertEqual(v, getattr(filetransfer, k))
|
||||
|
||||
|
||||
|
||||
class RawPacketDataTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{filetransfer.FileTransferClient} which explicitly craft certain
|
||||
less common protocol messages to exercise their handling.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.ftc = filetransfer.FileTransferClient()
|
||||
|
||||
|
||||
def test_packetSTATUS(self):
|
||||
"""
|
||||
A STATUS packet containing a result code, a message, and a language is
|
||||
parsed to produce the result of an outstanding request L{Deferred}.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUS)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg') + common.NS('lang')
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUS(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing the message and
|
||||
language from the STATUS packet.
|
||||
"""
|
||||
self.assertEqual(result[0], 'msg')
|
||||
self.assertEqual(result[1], 'lang')
|
||||
|
||||
|
||||
def test_packetSTATUSShort(self):
|
||||
"""
|
||||
A STATUS packet containing only a result code can also be parsed to
|
||||
produce the result of an outstanding request L{Deferred}. Such packets
|
||||
are sent by some SFTP implementations, though not strictly legal.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUSShort)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK)
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUSShort(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing empty strings, since
|
||||
the STATUS packet had neither a message nor a language.
|
||||
"""
|
||||
self.assertEqual(result[0], '')
|
||||
self.assertEqual(result[1], '')
|
||||
|
||||
|
||||
def test_packetSTATUSWithoutLang(self):
|
||||
"""
|
||||
A STATUS packet containing a result code and a message but no language
|
||||
can also be parsed to produce the result of an outstanding request
|
||||
L{Deferred}. Such packets are sent by some SFTP implementations, though
|
||||
not strictly legal.
|
||||
|
||||
@see: U{section 9.1<http://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-9.1>}
|
||||
of the SFTP Internet-Draft.
|
||||
"""
|
||||
d = defer.Deferred()
|
||||
d.addCallback(self._cbTestPacketSTATUSWithoutLang)
|
||||
self.ftc.openRequests[1] = d
|
||||
data = struct.pack('!LL', 1, filetransfer.FX_OK) + common.NS('msg')
|
||||
self.ftc.packet_STATUS(data)
|
||||
return d
|
||||
|
||||
|
||||
def _cbTestPacketSTATUSWithoutLang(self, result):
|
||||
"""
|
||||
Assert that the result is a two-tuple containing the message from the
|
||||
STATUS packet and an empty string, since the language was missing.
|
||||
"""
|
||||
self.assertEqual(result[0], 'msg')
|
||||
self.assertEqual(result[1], '')
|
@ -0,0 +1,82 @@
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.conch.ssh.forwarding}.
|
||||
"""
|
||||
|
||||
from socket import AF_INET6
|
||||
|
||||
from twisted.conch.ssh import forwarding
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.address import IPv6Address
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import MemoryReactorClock, StringTransport
|
||||
|
||||
|
||||
class TestSSHConnectForwardingChannel(unittest.TestCase):
|
||||
"""
|
||||
Unit and integration tests for L{SSHConnectForwardingChannel}.
|
||||
"""
|
||||
|
||||
def patchHostnameEndpointResolver(self, request, response):
|
||||
"""
|
||||
Patch L{forwarding.HostnameEndpoint} to respond with a predefined
|
||||
answer for DNS resolver requests.
|
||||
|
||||
@param request: Tupple of requested (hostname, port).
|
||||
@type request: C{tuppe}.
|
||||
|
||||
@param response: Tupple of (family, address) to respond the the
|
||||
associated C{request}.
|
||||
@type response: C{tuppe}.
|
||||
"""
|
||||
hostname, port = request
|
||||
family, address = response
|
||||
riggerResolver = {('fwd.example.org', 1234): (
|
||||
AF_INET6, None, None, None, ('::1', 1234))}
|
||||
|
||||
def riggedResolution(this, host, port):
|
||||
return defer.succeed([riggerResolver[(host, port)]])
|
||||
self.patch(
|
||||
forwarding.HostnameEndpoint, '_nameResolution', riggedResolution)
|
||||
|
||||
|
||||
def makeTCPConnection(self, reactor):
|
||||
"""
|
||||
Fake that connection was established for first connectTCP request made
|
||||
on C{reactor}.
|
||||
|
||||
@param reactor: Reactor on which to fake the connection.
|
||||
@type reactor: A reactor.
|
||||
"""
|
||||
factory = reactor.tcpClients[0][2]
|
||||
connector = reactor.connectors[0]
|
||||
protocol = factory.buildProtocol(None)
|
||||
transport = StringTransport(peerAddress=connector.getDestination())
|
||||
protocol.makeConnection(transport)
|
||||
|
||||
|
||||
def test_channelOpenHostnameRequests(self):
|
||||
"""
|
||||
When a hostname is sent as part of forwarding requests, it
|
||||
is resolved using HostnameEndpoint's resolver.
|
||||
"""
|
||||
sut = forwarding.SSHConnectForwardingChannel(
|
||||
hostport=('fwd.example.org', 1234))
|
||||
# Patch channel and resolver to not touch the network.
|
||||
sut._reactor = MemoryReactorClock()
|
||||
self.patchHostnameEndpointResolver(
|
||||
request=('fwd.example.org', 1234),
|
||||
response=(AF_INET6 ,'::1'),
|
||||
)
|
||||
|
||||
sut.channelOpen(None)
|
||||
|
||||
self.makeTCPConnection(sut._reactor)
|
||||
self.successResultOf(sut._channelOpenDeferred)
|
||||
# Channel is connected using a forwarding client to the resolved
|
||||
# address of the requested host.
|
||||
self.assertTrue(isinstance(sut.client, forwarding.SSHForwardingClient))
|
||||
self.assertEqual(
|
||||
IPv6Address('TCP', '::1', 1234), sut.client.transport.getPeer())
|
@ -0,0 +1,614 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_helper -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.conch.insults import helper
|
||||
from twisted.conch.insults.insults import G0, G1, G2, G3
|
||||
from twisted.conch.insults.insults import modes, privateModes
|
||||
from twisted.conch.insults.insults import (
|
||||
NORMAL, BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
|
||||
|
||||
from twisted.trial import unittest
|
||||
|
||||
WIDTH = 80
|
||||
HEIGHT = 24
|
||||
|
||||
class BufferTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.term = helper.TerminalBuffer()
|
||||
self.term.connectionMade()
|
||||
|
||||
def testInitialState(self):
|
||||
self.assertEqual(self.term.width, WIDTH)
|
||||
self.assertEqual(self.term.height, HEIGHT)
|
||||
self.assertEqual(str(self.term),
|
||||
'\n' * (HEIGHT - 1))
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
|
||||
def test_initialPrivateModes(self):
|
||||
"""
|
||||
Verify that only DEC Auto Wrap Mode (DECAWM) and DEC Text Cursor Enable
|
||||
Mode (DECTCEM) are initially in the Set Mode (SM) state.
|
||||
"""
|
||||
self.assertEqual(
|
||||
{privateModes.AUTO_WRAP: True,
|
||||
privateModes.CURSOR_MODE: True},
|
||||
self.term.privateModes)
|
||||
|
||||
|
||||
def test_carriageReturn(self):
|
||||
"""
|
||||
C{"\r"} moves the cursor to the first column in the current row.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
|
||||
self.term.insertAtCursor("\r")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
|
||||
|
||||
|
||||
def test_linefeed(self):
|
||||
"""
|
||||
C{"\n"} moves the cursor to the next row without changing the column.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 0))
|
||||
self.term.insertAtCursor("\n")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
|
||||
|
||||
|
||||
def test_newline(self):
|
||||
"""
|
||||
C{write} transforms C{"\n"} into C{"\r\n"}.
|
||||
"""
|
||||
self.term.cursorForward(5)
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 3))
|
||||
self.term.write("\n")
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
|
||||
|
||||
|
||||
def test_setPrivateModes(self):
|
||||
"""
|
||||
Verify that L{helper.TerminalBuffer.setPrivateModes} changes the Set
|
||||
Mode (SM) state to "set" for the private modes it is passed.
|
||||
"""
|
||||
expected = self.term.privateModes.copy()
|
||||
self.term.setPrivateModes([privateModes.SCROLL, privateModes.SCREEN])
|
||||
expected[privateModes.SCROLL] = True
|
||||
expected[privateModes.SCREEN] = True
|
||||
self.assertEqual(expected, self.term.privateModes)
|
||||
|
||||
|
||||
def test_resetPrivateModes(self):
|
||||
"""
|
||||
Verify that L{helper.TerminalBuffer.resetPrivateModes} changes the Set
|
||||
Mode (SM) state to "reset" for the private modes it is passed.
|
||||
"""
|
||||
expected = self.term.privateModes.copy()
|
||||
self.term.resetPrivateModes([privateModes.AUTO_WRAP, privateModes.CURSOR_MODE])
|
||||
del expected[privateModes.AUTO_WRAP]
|
||||
del expected[privateModes.CURSOR_MODE]
|
||||
self.assertEqual(expected, self.term.privateModes)
|
||||
|
||||
|
||||
def testCursorDown(self):
|
||||
self.term.cursorDown(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 3))
|
||||
self.term.cursorDown()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 4))
|
||||
self.term.cursorDown(HEIGHT)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
|
||||
def testCursorUp(self):
|
||||
self.term.cursorUp(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
self.term.cursorDown(20)
|
||||
self.term.cursorUp(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 19))
|
||||
|
||||
self.term.cursorUp(19)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
def testCursorForward(self):
|
||||
self.term.cursorForward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (2, 0))
|
||||
self.term.cursorForward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (4, 0))
|
||||
self.term.cursorForward(WIDTH)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (WIDTH, 0))
|
||||
|
||||
def testCursorBackward(self):
|
||||
self.term.cursorForward(10)
|
||||
self.term.cursorBackward(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (8, 0))
|
||||
self.term.cursorBackward(7)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (1, 0))
|
||||
self.term.cursorBackward(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
self.term.cursorBackward(1)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
|
||||
def testCursorPositioning(self):
|
||||
self.term.cursorPosition(3, 9)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (3, 9))
|
||||
|
||||
def testSimpleWriting(self):
|
||||
s = "Hello, world."
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testOvertype(self):
|
||||
s = "hello, world."
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(len(s))
|
||||
self.term.resetModes([modes.IRM])
|
||||
self.term.write("H")
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
("H" + s[1:]) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testInsert(self):
|
||||
s = "ello, world."
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(len(s))
|
||||
self.term.setModes([modes.IRM])
|
||||
self.term.write("H")
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
("H" + s) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testWritingInTheMiddle(self):
|
||||
s = "Hello, world."
|
||||
self.term.cursorDown(5)
|
||||
self.term.cursorForward(5)
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' * 5 +
|
||||
(self.term.fill * 5) + s + '\n' +
|
||||
'\n' * (HEIGHT - 7))
|
||||
|
||||
def testWritingWrappedAtEndOfLine(self):
|
||||
s = "Hello, world."
|
||||
self.term.cursorForward(WIDTH - 5)
|
||||
self.term.write(s)
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[:5].rjust(WIDTH) + '\n' +
|
||||
s[5:] + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testIndex(self):
|
||||
self.term.index()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
self.term.cursorDown(HEIGHT)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
self.term.index()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, HEIGHT - 1))
|
||||
|
||||
def testReverseIndex(self):
|
||||
self.term.reverseIndex()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 0))
|
||||
self.term.cursorDown(2)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
|
||||
self.term.reverseIndex()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
|
||||
def test_nextLine(self):
|
||||
"""
|
||||
C{nextLine} positions the cursor at the beginning of the row below the
|
||||
current row.
|
||||
"""
|
||||
self.term.nextLine()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 1))
|
||||
self.term.cursorForward(5)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (5, 1))
|
||||
self.term.nextLine()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (0, 2))
|
||||
|
||||
def testSaveCursor(self):
|
||||
self.term.cursorDown(5)
|
||||
self.term.cursorForward(7)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
|
||||
self.term.saveCursor()
|
||||
self.term.cursorDown(7)
|
||||
self.term.cursorBackward(3)
|
||||
self.assertEqual(self.term.reportCursorPosition(), (4, 12))
|
||||
self.term.restoreCursor()
|
||||
self.assertEqual(self.term.reportCursorPosition(), (7, 5))
|
||||
|
||||
def testSingleShifts(self):
|
||||
self.term.singleShift2()
|
||||
self.term.write('Hi')
|
||||
|
||||
ch = self.term.getCharacter(0, 0)
|
||||
self.assertEqual(ch[0], 'H')
|
||||
self.assertEqual(ch[1].charset, G2)
|
||||
|
||||
ch = self.term.getCharacter(1, 0)
|
||||
self.assertEqual(ch[0], 'i')
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
|
||||
self.term.singleShift3()
|
||||
self.term.write('!!')
|
||||
|
||||
ch = self.term.getCharacter(2, 0)
|
||||
self.assertEqual(ch[0], '!')
|
||||
self.assertEqual(ch[1].charset, G3)
|
||||
|
||||
ch = self.term.getCharacter(3, 0)
|
||||
self.assertEqual(ch[0], '!')
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
|
||||
def testShifting(self):
|
||||
s1 = "Hello"
|
||||
s2 = "World"
|
||||
s3 = "Bye!"
|
||||
self.term.write("Hello\n")
|
||||
self.term.shiftOut()
|
||||
self.term.write("World\n")
|
||||
self.term.shiftIn()
|
||||
self.term.write("Bye!\n")
|
||||
|
||||
g = G0
|
||||
h = 0
|
||||
for s in (s1, s2, s3):
|
||||
for i in range(len(s)):
|
||||
ch = self.term.getCharacter(i, h)
|
||||
self.assertEqual(ch[0], s[i])
|
||||
self.assertEqual(ch[1].charset, g)
|
||||
g = g == G0 and G1 or G0
|
||||
h += 1
|
||||
|
||||
def testGraphicRendition(self):
|
||||
self.term.selectGraphicRendition(BOLD, UNDERLINE, BLINK, REVERSE_VIDEO)
|
||||
self.term.write('W')
|
||||
self.term.selectGraphicRendition(NORMAL)
|
||||
self.term.write('X')
|
||||
self.term.selectGraphicRendition(BLINK)
|
||||
self.term.write('Y')
|
||||
self.term.selectGraphicRendition(BOLD)
|
||||
self.term.write('Z')
|
||||
|
||||
ch = self.term.getCharacter(0, 0)
|
||||
self.assertEqual(ch[0], 'W')
|
||||
self.assertTrue(ch[1].bold)
|
||||
self.assertTrue(ch[1].underline)
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertTrue(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(1, 0)
|
||||
self.assertEqual(ch[0], 'X')
|
||||
self.assertFalse(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].blink)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(2, 0)
|
||||
self.assertEqual(ch[0], 'Y')
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertFalse(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
ch = self.term.getCharacter(3, 0)
|
||||
self.assertEqual(ch[0], 'Z')
|
||||
self.assertTrue(ch[1].blink)
|
||||
self.assertTrue(ch[1].bold)
|
||||
self.assertFalse(ch[1].underline)
|
||||
self.assertFalse(ch[1].reverseVideo)
|
||||
|
||||
def testColorAttributes(self):
|
||||
s1 = "Merry xmas"
|
||||
s2 = "Just kidding"
|
||||
self.term.selectGraphicRendition(helper.FOREGROUND + helper.RED,
|
||||
helper.BACKGROUND + helper.GREEN)
|
||||
self.term.write(s1 + "\n")
|
||||
self.term.selectGraphicRendition(NORMAL)
|
||||
self.term.write(s2 + "\n")
|
||||
|
||||
for i in range(len(s1)):
|
||||
ch = self.term.getCharacter(i, 0)
|
||||
self.assertEqual(ch[0], s1[i])
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
self.assertEqual(ch[1].bold, False)
|
||||
self.assertEqual(ch[1].underline, False)
|
||||
self.assertEqual(ch[1].blink, False)
|
||||
self.assertEqual(ch[1].reverseVideo, False)
|
||||
self.assertEqual(ch[1].foreground, helper.RED)
|
||||
self.assertEqual(ch[1].background, helper.GREEN)
|
||||
|
||||
for i in range(len(s2)):
|
||||
ch = self.term.getCharacter(i, 1)
|
||||
self.assertEqual(ch[0], s2[i])
|
||||
self.assertEqual(ch[1].charset, G0)
|
||||
self.assertEqual(ch[1].bold, False)
|
||||
self.assertEqual(ch[1].underline, False)
|
||||
self.assertEqual(ch[1].blink, False)
|
||||
self.assertEqual(ch[1].reverseVideo, False)
|
||||
self.assertEqual(ch[1].foreground, helper.WHITE)
|
||||
self.assertEqual(ch[1].background, helper.BLACK)
|
||||
|
||||
def testEraseLine(self):
|
||||
s1 = 'line 1'
|
||||
s2 = 'line 2'
|
||||
s3 = 'line 3'
|
||||
self.term.write('\n'.join((s1, s2, s3)) + '\n')
|
||||
self.term.cursorPosition(1, 1)
|
||||
self.term.eraseLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
'\n' +
|
||||
s3 + '\n' +
|
||||
'\n' * (HEIGHT - 4))
|
||||
|
||||
def testEraseToLineEnd(self):
|
||||
s = 'Hello, world.'
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(5)
|
||||
self.term.eraseToLineEnd()
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[:-5] + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testEraseToLineBeginning(self):
|
||||
s = 'Hello, world.'
|
||||
self.term.write(s)
|
||||
self.term.cursorBackward(5)
|
||||
self.term.eraseToLineBeginning()
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s[-4:].rjust(len(s)) + '\n' +
|
||||
'\n' * (HEIGHT - 2))
|
||||
|
||||
def testEraseDisplay(self):
|
||||
self.term.write('Hello world\n')
|
||||
self.term.write('Goodbye world\n')
|
||||
self.term.eraseDisplay()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' * (HEIGHT - 1))
|
||||
|
||||
def testEraseToDisplayEnd(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2, '')))
|
||||
self.term.cursorPosition(5, 1)
|
||||
self.term.eraseToDisplayEnd()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
s2[:5] + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testEraseToDisplayBeginning(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2)))
|
||||
self.term.cursorPosition(5, 1)
|
||||
self.term.eraseToDisplayBeginning()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
'\n' +
|
||||
s2[6:].rjust(len(s2)) + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
def testLineInsertion(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2)))
|
||||
self.term.cursorPosition(7, 1)
|
||||
self.term.insertLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
'\n' +
|
||||
s2 + '\n' +
|
||||
'\n' * (HEIGHT - 4))
|
||||
|
||||
def testLineDeletion(self):
|
||||
s1 = "Hello world"
|
||||
s2 = "Middle words"
|
||||
s3 = "Goodbye world"
|
||||
self.term.write('\n'.join((s1, s2, s3)))
|
||||
self.term.cursorPosition(9, 1)
|
||||
self.term.deleteLine()
|
||||
|
||||
self.assertEqual(
|
||||
str(self.term),
|
||||
s1 + '\n' +
|
||||
s3 + '\n' +
|
||||
'\n' * (HEIGHT - 3))
|
||||
|
||||
class FakeDelayedCall:
|
||||
called = False
|
||||
cancelled = False
|
||||
def __init__(self, fs, timeout, f, a, kw):
|
||||
self.fs = fs
|
||||
self.timeout = timeout
|
||||
self.f = f
|
||||
self.a = a
|
||||
self.kw = kw
|
||||
|
||||
def active(self):
|
||||
return not (self.cancelled or self.called)
|
||||
|
||||
def cancel(self):
|
||||
self.cancelled = True
|
||||
# self.fs.calls.remove(self)
|
||||
|
||||
def call(self):
|
||||
self.called = True
|
||||
self.f(*self.a, **self.kw)
|
||||
|
||||
class FakeScheduler:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def callLater(self, timeout, f, *a, **kw):
|
||||
self.calls.append(FakeDelayedCall(self, timeout, f, a, kw))
|
||||
return self.calls[-1]
|
||||
|
||||
class ExpectTests(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.term = helper.ExpectableBuffer()
|
||||
self.term.connectionMade()
|
||||
self.fs = FakeScheduler()
|
||||
|
||||
def testSimpleString(self):
|
||||
result = []
|
||||
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d.addCallback(result.append)
|
||||
|
||||
self.term.write("greeting puny earthlings\n")
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello world\n")
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
self.assertEqual(len(self.fs.calls), 1)
|
||||
self.assertFalse(self.fs.calls[0].active())
|
||||
|
||||
def testBrokenUpString(self):
|
||||
result = []
|
||||
d = self.term.expect("hello world")
|
||||
d.addCallback(result.append)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello ")
|
||||
self.assertFalse(result)
|
||||
self.term.write("worl")
|
||||
self.assertFalse(result)
|
||||
self.term.write("d")
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
|
||||
|
||||
def testMultiple(self):
|
||||
result = []
|
||||
d1 = self.term.expect("hello ")
|
||||
d1.addCallback(result.append)
|
||||
d2 = self.term.expect("world")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.assertFalse(result)
|
||||
self.term.write("hello")
|
||||
self.assertFalse(result)
|
||||
self.term.write(" ")
|
||||
self.assertEqual(len(result), 1)
|
||||
self.term.write("world")
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].group(), "hello ")
|
||||
self.assertEqual(result[1].group(), "world")
|
||||
|
||||
def testSynchronous(self):
|
||||
self.term.write("hello world")
|
||||
|
||||
result = []
|
||||
d = self.term.expect("hello world")
|
||||
d.addCallback(result.append)
|
||||
self.assertTrue(result)
|
||||
self.assertEqual(result[0].group(), "hello world")
|
||||
|
||||
def testMultipleSynchronous(self):
|
||||
self.term.write("goodbye world")
|
||||
|
||||
result = []
|
||||
d1 = self.term.expect("bye")
|
||||
d1.addCallback(result.append)
|
||||
d2 = self.term.expect("world")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.assertEqual(len(result), 2)
|
||||
self.assertEqual(result[0].group(), "bye")
|
||||
self.assertEqual(result[1].group(), "world")
|
||||
|
||||
def _cbTestTimeoutFailure(self, res):
|
||||
self.assertTrue(hasattr(res, 'type'))
|
||||
self.assertEqual(res.type, helper.ExpectationTimeout)
|
||||
|
||||
def testTimeoutFailure(self):
|
||||
d = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d.addBoth(self._cbTestTimeoutFailure)
|
||||
self.fs.calls[0].call()
|
||||
|
||||
def testOverlappingTimeout(self):
|
||||
self.term.write("not zoomtastic")
|
||||
|
||||
result = []
|
||||
d1 = self.term.expect("hello world", timeout=1, scheduler=self.fs)
|
||||
d1.addBoth(self._cbTestTimeoutFailure)
|
||||
d2 = self.term.expect("zoom")
|
||||
d2.addCallback(result.append)
|
||||
|
||||
self.fs.calls[0].call()
|
||||
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0].group(), "zoom")
|
||||
|
||||
|
||||
|
||||
class CharacterAttributeTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for L{twisted.conch.insults.helper.CharacterAttribute}.
|
||||
"""
|
||||
def test_equality(self):
|
||||
"""
|
||||
L{CharacterAttribute}s must have matching character attribute values
|
||||
(bold, blink, underline, etc) with the same values to be considered
|
||||
equal.
|
||||
"""
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute())
|
||||
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute(charset=G0))
|
||||
|
||||
self.assertEqual(
|
||||
helper.CharacterAttribute(
|
||||
bold=True, underline=True, blink=False, reverseVideo=True,
|
||||
foreground=helper.BLUE),
|
||||
helper.CharacterAttribute(
|
||||
bold=True, underline=True, blink=False, reverseVideo=True,
|
||||
foreground=helper.BLUE))
|
||||
|
||||
self.assertNotEqual(
|
||||
helper.CharacterAttribute(),
|
||||
helper.CharacterAttribute(charset=G1))
|
||||
|
||||
self.assertNotEqual(
|
||||
helper.CharacterAttribute(bold=True),
|
||||
helper.CharacterAttribute(bold=False))
|
||||
|
||||
|
||||
def test_wantOneDeprecated(self):
|
||||
"""
|
||||
L{twisted.conch.insults.helper.CharacterAttribute.wantOne} emits
|
||||
a deprecation warning when invoked.
|
||||
"""
|
||||
# Trigger the deprecation warning.
|
||||
helper._FormattingState().wantOne(bold=True)
|
||||
|
||||
warningsShown = self.flushWarnings([self.test_wantOneDeprecated])
|
||||
self.assertEqual(len(warningsShown), 1)
|
||||
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(
|
||||
warningsShown[0]['message'],
|
||||
'twisted.conch.insults.helper.wantOne was deprecated in '
|
||||
'Twisted 13.1.0')
|
@ -0,0 +1,497 @@
|
||||
# -*- test-case-name: twisted.conch.test.test_insults -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
from twisted.python.reflect import namedAny
|
||||
from twisted.trial import unittest
|
||||
from twisted.test.proto_helpers import StringTransport
|
||||
|
||||
from twisted.conch.insults.insults import ServerProtocol, ClientProtocol
|
||||
from twisted.conch.insults.insults import CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL
|
||||
from twisted.conch.insults.insults import G0, G1
|
||||
from twisted.conch.insults.insults import modes
|
||||
|
||||
def _getattr(mock, name):
|
||||
return super(Mock, mock).__getattribute__(name)
|
||||
|
||||
def occurrences(mock):
|
||||
return _getattr(mock, 'occurrences')
|
||||
|
||||
def methods(mock):
|
||||
return _getattr(mock, 'methods')
|
||||
|
||||
def _append(mock, obj):
|
||||
occurrences(mock).append(obj)
|
||||
|
||||
default = object()
|
||||
|
||||
class Mock(object):
|
||||
callReturnValue = default
|
||||
|
||||
def __init__(self, methods=None, callReturnValue=default):
|
||||
"""
|
||||
@param methods: Mapping of names to return values
|
||||
@param callReturnValue: object __call__ should return
|
||||
"""
|
||||
self.occurrences = []
|
||||
if methods is None:
|
||||
methods = {}
|
||||
self.methods = methods
|
||||
if callReturnValue is not default:
|
||||
self.callReturnValue = callReturnValue
|
||||
|
||||
def __call__(self, *a, **kw):
|
||||
returnValue = _getattr(self, 'callReturnValue')
|
||||
if returnValue is default:
|
||||
returnValue = Mock()
|
||||
# _getattr(self, 'occurrences').append(('__call__', returnValue, a, kw))
|
||||
_append(self, ('__call__', returnValue, a, kw))
|
||||
return returnValue
|
||||
|
||||
def __getattribute__(self, name):
|
||||
methods = _getattr(self, 'methods')
|
||||
if name in methods:
|
||||
attrValue = Mock(callReturnValue=methods[name])
|
||||
else:
|
||||
attrValue = Mock()
|
||||
# _getattr(self, 'occurrences').append((name, attrValue))
|
||||
_append(self, (name, attrValue))
|
||||
return attrValue
|
||||
|
||||
class MockMixin:
|
||||
def assertCall(self, occurrence, methodName, expectedPositionalArgs=(),
|
||||
expectedKeywordArgs={}):
|
||||
attr, mock = occurrence
|
||||
self.assertEqual(attr, methodName)
|
||||
self.assertEqual(len(occurrences(mock)), 1)
|
||||
[(call, result, args, kw)] = occurrences(mock)
|
||||
self.assertEqual(call, "__call__")
|
||||
self.assertEqual(args, expectedPositionalArgs)
|
||||
self.assertEqual(kw, expectedKeywordArgs)
|
||||
return result
|
||||
|
||||
|
||||
_byteGroupingTestTemplate = """\
|
||||
def testByte%(groupName)s(self):
|
||||
transport = StringTransport()
|
||||
proto = Mock()
|
||||
parser = self.protocolFactory(lambda: proto)
|
||||
parser.factory = self
|
||||
parser.makeConnection(transport)
|
||||
|
||||
bytes = self.TEST_BYTES
|
||||
while bytes:
|
||||
chunk = bytes[:%(bytesPer)d]
|
||||
bytes = bytes[%(bytesPer)d:]
|
||||
parser.dataReceived(chunk)
|
||||
|
||||
self.verifyResults(transport, proto, parser)
|
||||
"""
|
||||
class ByteGroupingsMixin(MockMixin):
|
||||
protocolFactory = None
|
||||
|
||||
for word, n in [('Pairs', 2), ('Triples', 3), ('Quads', 4), ('Quints', 5), ('Sexes', 6)]:
|
||||
exec _byteGroupingTestTemplate % {'groupName': word, 'bytesPer': n}
|
||||
del word, n
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
result = self.assertCall(occurrences(proto).pop(0), "makeConnection", (parser,))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
del _byteGroupingTestTemplate
|
||||
|
||||
class ServerArrowKeysTests(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
# All the arrow keys once
|
||||
TEST_BYTES = '\x1b[A\x1b[B\x1b[C\x1b[D'
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for arrow in (parser.UP_ARROW, parser.DOWN_ARROW,
|
||||
parser.RIGHT_ARROW, parser.LEFT_ARROW):
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (arrow, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
|
||||
class PrintableCharactersTests(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
# Some letters and digits, first on their own, then capitalized,
|
||||
# then modified with alt
|
||||
|
||||
TEST_BYTES = 'abc123ABC!@#\x1ba\x1bb\x1bc\x1b1\x1b2\x1b3'
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for char in 'abc123ABC!@#':
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
for char in 'abc123':
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (char, parser.ALT))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
|
||||
occs = occurrences(proto)
|
||||
self.assertFalse(occs, "%r should have been []" % (occs,))
|
||||
|
||||
class ServerFunctionKeysTests(ByteGroupingsMixin, unittest.TestCase):
|
||||
"""Test for parsing and dispatching function keys (F1 - F12)
|
||||
"""
|
||||
protocolFactory = ServerProtocol
|
||||
|
||||
byteList = []
|
||||
for bytes in ('OP', 'OQ', 'OR', 'OS', # F1 - F4
|
||||
'15~', '17~', '18~', '19~', # F5 - F8
|
||||
'20~', '21~', '23~', '24~'): # F9 - F12
|
||||
byteList.append('\x1b[' + bytes)
|
||||
TEST_BYTES = ''.join(byteList)
|
||||
del byteList, bytes
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
for funcNum in range(1, 13):
|
||||
funcArg = getattr(parser, 'F%d' % (funcNum,))
|
||||
result = self.assertCall(occurrences(proto).pop(0), "keystrokeReceived", (funcArg, None))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
class ClientCursorMovementTests(ByteGroupingsMixin, unittest.TestCase):
|
||||
protocolFactory = ClientProtocol
|
||||
|
||||
d2 = "\x1b[2B"
|
||||
r4 = "\x1b[4C"
|
||||
u1 = "\x1b[A"
|
||||
l2 = "\x1b[2D"
|
||||
# Move the cursor down two, right four, up one, left two, up one, left two
|
||||
TEST_BYTES = d2 + r4 + u1 + l2 + u1 + l2
|
||||
del d2, r4, u1, l2
|
||||
|
||||
def verifyResults(self, transport, proto, parser):
|
||||
ByteGroupingsMixin.verifyResults(self, transport, proto, parser)
|
||||
|
||||
for (method, count) in [('Down', 2), ('Forward', 4), ('Up', 1),
|
||||
('Backward', 2), ('Up', 1), ('Backward', 2)]:
|
||||
result = self.assertCall(occurrences(proto).pop(0), "cursor" + method, (count,))
|
||||
self.assertEqual(occurrences(result), [])
|
||||
self.assertFalse(occurrences(proto))
|
||||
|
||||
class ClientControlSequencesTests(unittest.TestCase, MockMixin):
|
||||
def setUp(self):
|
||||
self.transport = StringTransport()
|
||||
self.proto = Mock()
|
||||
self.parser = ClientProtocol(lambda: self.proto)
|
||||
self.parser.factory = self
|
||||
self.parser.makeConnection(self.transport)
|
||||
result = self.assertCall(occurrences(self.proto).pop(0), "makeConnection", (self.parser,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
def testSimpleCardinals(self):
|
||||
self.parser.dataReceived(
|
||||
''.join([''.join(['\x1b[' + str(n) + ch for n in ('', 2, 20, 200)]) for ch in 'BACD']))
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for meth in ("Down", "Up", "Forward", "Backward"):
|
||||
for count in (1, 2, 20, 200):
|
||||
result = self.assertCall(occs.pop(0), "cursor" + meth, (count,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testScrollRegion(self):
|
||||
self.parser.dataReceived('\x1b[5;22r\x1b[r')
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setScrollRegion", (5, 22))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setScrollRegion", (None, None))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testHeightAndWidth(self):
|
||||
self.parser.dataReceived("\x1b#3\x1b#4\x1b#5\x1b#6")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleHeightLine", (True,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleHeightLine", (False,))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleWidthLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "doubleWidthLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCharacterSet(self):
|
||||
self.parser.dataReceived(
|
||||
''.join([''.join(['\x1b' + g + n for n in 'AB012']) for g in '()']))
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for which in (G0, G1):
|
||||
for charset in (CS_UK, CS_US, CS_DRAWING, CS_ALTERNATE, CS_ALTERNATE_SPECIAL):
|
||||
result = self.assertCall(occs.pop(0), "selectCharacterSet", (charset, which))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testShifting(self):
|
||||
self.parser.dataReceived("\x15\x14")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "shiftIn")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "shiftOut")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testSingleShifts(self):
|
||||
self.parser.dataReceived("\x1bN\x1bO")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleShift2")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "singleShift3")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testKeypadMode(self):
|
||||
self.parser.dataReceived("\x1b=\x1b>")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "applicationKeypadMode")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "numericKeypadMode")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCursor(self):
|
||||
self.parser.dataReceived("\x1b7\x1b8")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "saveCursor")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "restoreCursor")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testReset(self):
|
||||
self.parser.dataReceived("\x1bc")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reset")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testIndex(self):
|
||||
self.parser.dataReceived("\x1bD\x1bM\x1bE")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "index")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reverseIndex")
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "nextLine")
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testModes(self):
|
||||
self.parser.dataReceived(
|
||||
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "h")
|
||||
self.parser.dataReceived(
|
||||
"\x1b[" + ';'.join(map(str, [modes.KAM, modes.IRM, modes.LNM])) + "l")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "setModes", ([modes.KAM, modes.IRM, modes.LNM],))
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "resetModes", ([modes.KAM, modes.IRM, modes.LNM],))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testErasure(self):
|
||||
self.parser.dataReceived(
|
||||
"\x1b[K\x1b[1K\x1b[2K\x1b[J\x1b[1J\x1b[2J\x1b[3P")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for meth in ("eraseToLineEnd", "eraseToLineBeginning", "eraseLine",
|
||||
"eraseToDisplayEnd", "eraseToDisplayBeginning",
|
||||
"eraseDisplay"):
|
||||
result = self.assertCall(occs.pop(0), meth)
|
||||
self.assertFalse(occurrences(result))
|
||||
|
||||
result = self.assertCall(occs.pop(0), "deleteCharacter", (3,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testLineDeletion(self):
|
||||
self.parser.dataReceived("\x1b[M\x1b[3M")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for arg in (1, 3):
|
||||
result = self.assertCall(occs.pop(0), "deleteLine", (arg,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testLineInsertion(self):
|
||||
self.parser.dataReceived("\x1b[L\x1b[3L")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
for arg in (1, 3):
|
||||
result = self.assertCall(occs.pop(0), "insertLine", (arg,))
|
||||
self.assertFalse(occurrences(result))
|
||||
self.assertFalse(occs)
|
||||
|
||||
def testCursorPosition(self):
|
||||
methods(self.proto)['reportCursorPosition'] = (6, 7)
|
||||
self.parser.dataReceived("\x1b[6n")
|
||||
self.assertEqual(self.transport.value(), "\x1b[7;8R")
|
||||
occs = occurrences(self.proto)
|
||||
|
||||
result = self.assertCall(occs.pop(0), "reportCursorPosition")
|
||||
# This isn't really an interesting assert, since it only tests that
|
||||
# our mock setup is working right, but I'll include it anyway.
|
||||
self.assertEqual(result, (6, 7))
|
||||
|
||||
|
||||
def test_applicationDataBytes(self):
|
||||
"""
|
||||
Contiguous non-control bytes are passed to a single call to the
|
||||
C{write} method of the terminal to which the L{ClientProtocol} is
|
||||
connected.
|
||||
"""
|
||||
occs = occurrences(self.proto)
|
||||
self.parser.dataReceived('a')
|
||||
self.assertCall(occs.pop(0), "write", ("a",))
|
||||
self.parser.dataReceived('bc')
|
||||
self.assertCall(occs.pop(0), "write", ("bc",))
|
||||
|
||||
|
||||
def _applicationDataTest(self, data, calls):
|
||||
occs = occurrences(self.proto)
|
||||
self.parser.dataReceived(data)
|
||||
while calls:
|
||||
self.assertCall(occs.pop(0), *calls.pop(0))
|
||||
self.assertFalse(occs, "No other calls should happen: %r" % (occs,))
|
||||
|
||||
|
||||
def test_shiftInAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a shift-in command are passed to a
|
||||
call to C{write} before the terminal's C{shiftIn} method is called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x15', [
|
||||
("write", ("ab",)),
|
||||
("shiftIn",)])
|
||||
|
||||
|
||||
def test_shiftOutAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a shift-out command are passed to a
|
||||
call to C{write} before the terminal's C{shiftOut} method is called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x14', [
|
||||
("write", ("ab",)),
|
||||
("shiftOut",)])
|
||||
|
||||
|
||||
def test_cursorBackwardAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by a cursor-backward command are passed
|
||||
to a call to C{write} before the terminal's C{cursorBackward} method is
|
||||
called.
|
||||
"""
|
||||
self._applicationDataTest(
|
||||
'ab\x08', [
|
||||
("write", ("ab",)),
|
||||
("cursorBackward",)])
|
||||
|
||||
|
||||
def test_escapeAfterApplicationData(self):
|
||||
"""
|
||||
Application data bytes followed by an escape character are passed to a
|
||||
call to C{write} before the terminal's handler method for the escape is
|
||||
called.
|
||||
"""
|
||||
# Test a short escape
|
||||
self._applicationDataTest(
|
||||
'ab\x1bD', [
|
||||
("write", ("ab",)),
|
||||
("index",)])
|
||||
|
||||
# And a long escape
|
||||
self._applicationDataTest(
|
||||
'ab\x1b[4h', [
|
||||
("write", ("ab",)),
|
||||
("setModes", ([4],))])
|
||||
|
||||
# There's some other cases too, but they're all handled by the same
|
||||
# codepaths as above.
|
||||
|
||||
|
||||
|
||||
class ServerProtocolOutputTests(unittest.TestCase):
|
||||
"""
|
||||
Tests for the bytes L{ServerProtocol} writes to its transport when its
|
||||
methods are called.
|
||||
"""
|
||||
def test_nextLine(self):
|
||||
"""
|
||||
L{ServerProtocol.nextLine} writes C{"\r\n"} to its transport.
|
||||
"""
|
||||
# Why doesn't it write ESC E? Because ESC E is poorly supported. For
|
||||
# example, gnome-terminal (many different versions) fails to scroll if
|
||||
# it receives ESC E and the cursor is already on the last row.
|
||||
protocol = ServerProtocol()
|
||||
transport = StringTransport()
|
||||
protocol.makeConnection(transport)
|
||||
protocol.nextLine()
|
||||
self.assertEqual(transport.value(), "\r\n")
|
||||
|
||||
|
||||
|
||||
class DeprecationsTests(unittest.TestCase):
|
||||
"""
|
||||
Tests to ensure deprecation of L{insults.colors} and L{insults.client}
|
||||
"""
|
||||
|
||||
def ensureDeprecated(self, message):
|
||||
"""
|
||||
Ensures that the correct deprecation warning was issued.
|
||||
"""
|
||||
warnings = self.flushWarnings()
|
||||
self.assertIs(warnings[0]['category'], DeprecationWarning)
|
||||
self.assertEqual(warnings[0]['message'], message)
|
||||
self.assertEqual(len(warnings), 1)
|
||||
|
||||
|
||||
def test_colors(self):
|
||||
"""
|
||||
The L{insults.colors} module is deprecated
|
||||
"""
|
||||
namedAny('twisted.conch.insults.colors')
|
||||
self.ensureDeprecated("twisted.conch.insults.colors was deprecated "
|
||||
"in Twisted 10.1.0: Please use "
|
||||
"twisted.conch.insults.helper instead.")
|
||||
|
||||
|
||||
def test_client(self):
|
||||
"""
|
||||
The L{insults.client} module is deprecated
|
||||
"""
|
||||
namedAny('twisted.conch.insults.client')
|
||||
self.ensureDeprecated("twisted.conch.insults.client was deprecated "
|
||||
"in Twisted 10.1.0: Please use "
|
||||
"twisted.conch.insults.insults instead.")
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user