You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
209 lines
5.7 KiB
Python
209 lines
5.7 KiB
Python
import io
|
|
import struct
|
|
import hashlib
|
|
import ecdsa
|
|
import base64
|
|
|
|
import logging
|
|
log = logging.getLogger(__name__)
|
|
|
|
def send(conn, data, fmt=None):
|
|
if fmt:
|
|
data = struct.pack(fmt, *data)
|
|
conn.sendall(data)
|
|
|
|
def recv(conn, size):
|
|
try:
|
|
fmt = size
|
|
size = struct.calcsize(fmt)
|
|
except TypeError:
|
|
fmt = None
|
|
try:
|
|
_read = conn.recv
|
|
except AttributeError:
|
|
_read = conn.read
|
|
|
|
res = io.BytesIO()
|
|
while size > 0:
|
|
buf = _read(size)
|
|
if not buf:
|
|
raise EOFError
|
|
size = size - len(buf)
|
|
res.write(buf)
|
|
res = res.getvalue()
|
|
if fmt:
|
|
return struct.unpack(fmt, res)
|
|
else:
|
|
return res
|
|
|
|
|
|
def read_frame(conn):
|
|
size, = recv(conn, '>L')
|
|
return recv(conn, size)
|
|
|
|
def bytes2num(s):
|
|
res = 0
|
|
for i, c in enumerate(reversed(bytearray(s))):
|
|
res += c << (i * 8)
|
|
return res
|
|
|
|
|
|
def parse_pubkey(blob):
|
|
s = io.BytesIO(blob)
|
|
key_type = read_frame(s)
|
|
log.debug('key type: %s', key_type)
|
|
curve = read_frame(s)
|
|
log.debug('curve name: %s', curve)
|
|
point = read_frame(s)
|
|
_type, point = point[:1], point[1:]
|
|
assert _type == DER_OCTET_STRING
|
|
size = len(point) // 2
|
|
assert len(point) == 2 * size
|
|
coords = map(bytes2num, [point[:size], point[size:]])
|
|
log.debug('coordinates: %s', coords)
|
|
fp = fingerprint(blob)
|
|
result = {
|
|
'point': tuple(coords), 'curve': curve,
|
|
'fingerprint': fp,
|
|
'type': key_type,
|
|
'blob': blob, 'size': size
|
|
}
|
|
return result
|
|
|
|
def list_keys(c):
|
|
send(c, [0x1, 0xB], '>LB')
|
|
buf = io.BytesIO(read_frame(c))
|
|
assert recv(buf, '>B') == (0xC,)
|
|
num, = recv(buf, '>L')
|
|
for i in range(num):
|
|
k = parse_pubkey(read_frame(buf))
|
|
k['comment'] = read_frame(buf)
|
|
yield k
|
|
|
|
def frame(*msgs):
|
|
res = io.BytesIO()
|
|
for msg in msgs:
|
|
res.write(msg)
|
|
msg = res.getvalue()
|
|
return pack('L', len(msg)) + msg
|
|
|
|
SSH_AGENTC_REQUEST_RSA_IDENTITIES = 1
|
|
SSH_AGENT_RSA_IDENTITIES_ANSWER = 2
|
|
|
|
SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
|
|
|
|
SSH2_AGENTC_REQUEST_IDENTITIES = 11
|
|
SSH2_AGENT_IDENTITIES_ANSWER = 12
|
|
SSH2_AGENTC_SIGN_REQUEST = 13
|
|
SSH2_AGENT_SIGN_RESPONSE = 14
|
|
SSH2_AGENTC_ADD_IDENTITY = 17
|
|
SSH2_AGENTC_REMOVE_IDENTITY = 18
|
|
SSH2_AGENTC_REMOVE_ALL_IDENTITIES = 19
|
|
|
|
def pack(fmt, *args):
|
|
return struct.pack('>' + fmt, *args)
|
|
|
|
def legacy_pubs(buf, keys, signer):
|
|
code = pack('B', SSH_AGENT_RSA_IDENTITIES_ANSWER)
|
|
num = pack('L', 0) # no SSH v1 keys
|
|
return frame(code, num)
|
|
|
|
def list_pubs(buf, keys, signer):
|
|
code = pack('B', SSH2_AGENT_IDENTITIES_ANSWER)
|
|
num = pack('L', len(keys))
|
|
log.debug('available keys: %s', [k['name'] for k in keys])
|
|
for i, k in enumerate(keys):
|
|
log.debug('%2d) %s', i+1, k['fingerprint'])
|
|
pubs = [frame(k['blob']) + frame(k['name']) for k in keys]
|
|
return frame(code, num, *pubs)
|
|
|
|
def fingerprint(blob):
|
|
digest = hashlib.md5(blob).digest()
|
|
return ':'.join('{:02x}'.format(c) for c in bytearray(digest))
|
|
|
|
def num2bytes(value, size):
|
|
res = []
|
|
for i in range(size):
|
|
res.append(value & 0xFF)
|
|
value = value >> 8
|
|
assert value == 0
|
|
return bytearray(list(reversed(res)))
|
|
|
|
def sign_message(buf, keys, signer):
|
|
key = parse_pubkey(read_frame(buf))
|
|
log.debug('looking for %s', key['fingerprint'])
|
|
blob = read_frame(buf)
|
|
|
|
for k in keys:
|
|
if (k['fingerprint']) == (key['fingerprint']):
|
|
log.debug('using key %r (%s)', k['name'], k['fingerprint'])
|
|
key = k
|
|
break
|
|
else:
|
|
raise ValueError('key not found')
|
|
|
|
log.debug('signing %d-byte blob', len(blob))
|
|
r, s = signer(label=k['name'], blob=blob)
|
|
signature = (r, s)
|
|
|
|
log.debug('signature: %s', signature)
|
|
|
|
curve = ecdsa.curves.NIST256p
|
|
point = ecdsa.ellipticcurve.Point(curve.curve, *key['point'])
|
|
vk = ecdsa.VerifyingKey.from_public_point(point, curve, hashlib.sha256)
|
|
success = vk.verify(signature=signature, data=blob,
|
|
sigdecode=lambda sig, _: sig)
|
|
log.info('signature status: %s', 'OK' if success else 'ERROR')
|
|
if not success:
|
|
raise ValueError('invalid signature')
|
|
|
|
sig_bytes = io.BytesIO()
|
|
for x in signature:
|
|
sig_bytes.write(frame(b'\x00' + num2bytes(x, key['size'])))
|
|
sig_bytes = sig_bytes.getvalue()
|
|
log.debug('signature size: %d bytes', len(sig_bytes))
|
|
|
|
data = frame(frame(key['type']), frame(sig_bytes))
|
|
code = pack('B', SSH2_AGENT_SIGN_RESPONSE)
|
|
return frame(code, data)
|
|
|
|
handlers = {
|
|
SSH_AGENTC_REQUEST_RSA_IDENTITIES: legacy_pubs,
|
|
SSH2_AGENTC_REQUEST_IDENTITIES: list_pubs,
|
|
SSH2_AGENTC_SIGN_REQUEST: sign_message,
|
|
}
|
|
|
|
def handle_connection(conn, keys, signer):
|
|
try:
|
|
log.debug('welcome agent')
|
|
while True:
|
|
msg = read_frame(conn)
|
|
buf = io.BytesIO(msg)
|
|
code, = recv(buf, '>B')
|
|
log.debug('request: %d bytes', len(msg))
|
|
handler = handlers[code]
|
|
log.debug('calling %s()', handler.__name__)
|
|
reply = handler(buf=buf, keys=keys, signer=signer)
|
|
log.debug('reply: %d bytes', len(reply))
|
|
send(conn, reply)
|
|
except EOFError:
|
|
log.debug('goodbye agent')
|
|
except:
|
|
log.exception('error')
|
|
raise
|
|
|
|
DER_OCTET_STRING = b'\x04'
|
|
|
|
def load_public_key(filename):
|
|
with open(filename) as f:
|
|
return parse_public_key(f.read())
|
|
|
|
def parse_public_key(data):
|
|
file_type, base64blob, name = data.split()
|
|
blob = base64.b64decode(base64blob)
|
|
result = parse_pubkey(blob)
|
|
result['name'] = name.encode('ascii')
|
|
assert result['type'] == file_type.encode('ascii')
|
|
log.debug('loaded %s %s', file_type, result['fingerprint'])
|
|
return result
|