Allow loading previously exported SSH public keys from a file

This commit is contained in:
Roman Zeyde 2017-05-13 12:47:48 +03:00
parent 53d43cba29
commit f358ca29d4
No known key found for this signature in database
GPG Key ID: 87CAE5FA46917CBB

View File

@ -2,6 +2,7 @@
import argparse
import contextlib
import functools
import io
import logging
import os
import re
@ -133,27 +134,36 @@ def handle_connection_error(func):
return wrapper
def parse_config(fname):
def parse_config(contents):
"""Parse config file into a list of Identity objects."""
contents = open(fname).read()
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
yield device.interface.Identity(identity_str=identity_str,
curve_name=curve_name)
def import_public_keys(contents):
"""Load (previously exported) SSH public keys from a file's contents."""
for line in io.StringIO(contents):
# Verify this line represents valid SSH public key
formats.import_public_key(line)
yield line
class JustInTimeConnection(object):
"""Connect to the device just before the needed operation."""
def __init__(self, conn_factory, identities):
def __init__(self, conn_factory, identities, public_keys=None):
"""Create a JIT connection object."""
self.conn_factory = conn_factory
self.identities = identities
self.public_keys = util.memoize(self._public_keys) # a simple cache
self.public_keys_cache = public_keys
def _public_keys(self):
def public_keys(self):
"""Return a list of SSH public keys (in textual format)."""
conn = self.conn_factory()
return conn.export_public_keys(self.identities)
if not self.public_keys_cache:
conn = self.conn_factory()
self.public_keys_cache = conn.export_public_keys(self.identities)
return self.public_keys_cache
def parse_public_keys(self):
"""Parse SSH public keys into dictionaries."""
@ -175,8 +185,14 @@ def main(device_type):
args = create_agent_parser().parse_args()
util.setup_logging(verbosity=args.verbose)
public_keys = None
if args.identity.startswith('/'):
identities = list(parse_config(fname=args.identity))
filename = args.identity
contents = open(filename, 'rb').read().decode('ascii')
# Allow loading previously exported SSH public keys
if filename.endswith('.pub'):
public_keys = list(import_public_keys(contents))
identities = list(parse_config(contents))
else:
identities = [device.interface.Identity(
identity_str=args.identity, curve_name=args.ecdsa_curve_name)]
@ -197,7 +213,7 @@ def main(device_type):
conn = JustInTimeConnection(
conn_factory=lambda: client.Client(device_type()),
identities=identities)
identities=identities, public_keys=public_keys)
if command:
return run_server(conn=conn, command=command, debug=args.debug,
timeout=args.timeout)