Allow loading previously exported SSH public keys from a file
This commit is contained in:
parent
53d43cba29
commit
f358ca29d4
@ -2,6 +2,7 @@
|
||||
import argparse
|
||||
import contextlib
|
||||
import functools
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
@ -133,27 +134,36 @@ def handle_connection_error(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def parse_config(fname):
|
||||
def parse_config(contents):
|
||||
"""Parse config file into a list of Identity objects."""
|
||||
contents = open(fname).read()
|
||||
for identity_str, curve_name in re.findall(r'\<(.*?)\|(.*?)\>', contents):
|
||||
yield device.interface.Identity(identity_str=identity_str,
|
||||
curve_name=curve_name)
|
||||
|
||||
|
||||
def import_public_keys(contents):
|
||||
"""Load (previously exported) SSH public keys from a file's contents."""
|
||||
for line in io.StringIO(contents):
|
||||
# Verify this line represents valid SSH public key
|
||||
formats.import_public_key(line)
|
||||
yield line
|
||||
|
||||
|
||||
class JustInTimeConnection(object):
|
||||
"""Connect to the device just before the needed operation."""
|
||||
|
||||
def __init__(self, conn_factory, identities):
|
||||
def __init__(self, conn_factory, identities, public_keys=None):
|
||||
"""Create a JIT connection object."""
|
||||
self.conn_factory = conn_factory
|
||||
self.identities = identities
|
||||
self.public_keys = util.memoize(self._public_keys) # a simple cache
|
||||
self.public_keys_cache = public_keys
|
||||
|
||||
def _public_keys(self):
|
||||
def public_keys(self):
|
||||
"""Return a list of SSH public keys (in textual format)."""
|
||||
conn = self.conn_factory()
|
||||
return conn.export_public_keys(self.identities)
|
||||
if not self.public_keys_cache:
|
||||
conn = self.conn_factory()
|
||||
self.public_keys_cache = conn.export_public_keys(self.identities)
|
||||
return self.public_keys_cache
|
||||
|
||||
def parse_public_keys(self):
|
||||
"""Parse SSH public keys into dictionaries."""
|
||||
@ -175,8 +185,14 @@ def main(device_type):
|
||||
args = create_agent_parser().parse_args()
|
||||
util.setup_logging(verbosity=args.verbose)
|
||||
|
||||
public_keys = None
|
||||
if args.identity.startswith('/'):
|
||||
identities = list(parse_config(fname=args.identity))
|
||||
filename = args.identity
|
||||
contents = open(filename, 'rb').read().decode('ascii')
|
||||
# Allow loading previously exported SSH public keys
|
||||
if filename.endswith('.pub'):
|
||||
public_keys = list(import_public_keys(contents))
|
||||
identities = list(parse_config(contents))
|
||||
else:
|
||||
identities = [device.interface.Identity(
|
||||
identity_str=args.identity, curve_name=args.ecdsa_curve_name)]
|
||||
@ -197,7 +213,7 @@ def main(device_type):
|
||||
|
||||
conn = JustInTimeConnection(
|
||||
conn_factory=lambda: client.Client(device_type()),
|
||||
identities=identities)
|
||||
identities=identities, public_keys=public_keys)
|
||||
if command:
|
||||
return run_server(conn=conn, command=command, debug=args.debug,
|
||||
timeout=args.timeout)
|
||||
|
Loading…
Reference in New Issue
Block a user