gpg: parse curve OID from public key to select curve name

This commit is contained in:
Roman Zeyde 2016-10-14 23:10:05 +03:00
parent 8f19690943
commit 42813ddbb4
5 changed files with 19 additions and 17 deletions

View File

@ -204,4 +204,5 @@ def get_ecdh_curve_name(signature_curve_name):
return { return {
CURVE_NIST256: ECDH_NIST256, CURVE_NIST256: ECDH_NIST256,
CURVE_ED25519: ECDH_CURVE25519, CURVE_ED25519: ECDH_CURVE25519,
ECDH_CURVE25519: ECDH_CURVE25519,
}[signature_curve_name] }[signature_curve_name]

View File

@ -83,9 +83,15 @@ def _parse_ed25519_verifier(mpi):
return _ed25519_verify, vk return _ed25519_verify, vk
def _parse_curve25519_verifier(_):
log.warning('Curve25519 ECDH is not verified')
return None, None
SUPPORTED_CURVES = { SUPPORTED_CURVES = {
b'\x2A\x86\x48\xCE\x3D\x03\x01\x07': _parse_nist256p1_verifier, b'\x2A\x86\x48\xCE\x3D\x03\x01\x07': _parse_nist256p1_verifier,
b'\x2B\x06\x01\x04\x01\xDA\x47\x0F\x01': _parse_ed25519_verifier, b'\x2B\x06\x01\x04\x01\xDA\x47\x0F\x01': _parse_ed25519_verifier,
b'\x2B\x06\x01\x04\x01\x97\x55\x01\x05\x01': _parse_curve25519_verifier,
} }
RSA_ALGO_IDS = {1, 2, 3} RSA_ALGO_IDS = {1, 2, 3}
@ -168,6 +174,7 @@ def _parse_pubkey(stream, packet_type='pubkey'):
oid_size = stream.readfmt('B') oid_size = stream.readfmt('B')
oid = stream.read(oid_size) oid = stream.read(oid_size)
assert oid in SUPPORTED_CURVES, util.hexlify(oid) assert oid in SUPPORTED_CURVES, util.hexlify(oid)
p['curve_oid'] = oid
parser = SUPPORTED_CURVES[oid] parser = SUPPORTED_CURVES[oid]
mpi = parse_mpi(stream) mpi = parse_mpi(stream)

View File

@ -3,10 +3,11 @@ import logging
import time import time
from . import decode, device, keyring, protocol from . import decode, device, keyring, protocol
from .. import formats, util from .. import util
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def _time_format(t): def _time_format(t):
return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t)) return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(t))
@ -109,10 +110,10 @@ def create_subkey(primary_bytes, pubkey, signer_func):
def load_from_public_key(pubkey_dict): def load_from_public_key(pubkey_dict):
"""Load correct public key from the device.""" """Load correct public key from the device."""
log.debug('pubkey_dict: %s', pubkey_dict)
user_id = pubkey_dict['user_id'] user_id = pubkey_dict['user_id']
created = pubkey_dict['created'] created = pubkey_dict['created']
curve_name = protocol.find_curve_by_algo_id(pubkey_dict['algo']) curve_name = protocol.get_curve_name_by_oid(pubkey_dict['curve_oid'])
assert curve_name in formats.SUPPORTED_CURVES
ecdh = (pubkey_dict['algo'] == protocol.ECDH_ALGO_ID) ecdh = (pubkey_dict['algo'] == protocol.ECDH_ALGO_ID)
conn = device.HardwareSigner(user_id, curve_name=curve_name) conn = device.HardwareSigner(user_id, curve_name=curve_name)

View File

@ -155,14 +155,12 @@ ECDH_ALGO_ID = 18
CUSTOM_SUBPACKET = subpacket(100, b'TREZOR-GPG') # marks "our" pubkey CUSTOM_SUBPACKET = subpacket(100, b'TREZOR-GPG') # marks "our" pubkey
def find_curve_by_algo_id(algo_id): def get_curve_name_by_oid(oid):
"""Find curve name that matches a public key algorith ID.""" """Return curve name matching specified OID, or raise KeyError."""
if algo_id == ECDH_ALGO_ID: for curve_name, info in SUPPORTED_CURVES.items():
return formats.CURVE_NIST256 if info['oid'] == oid:
return curve_name
curve_name, = [name for name, info in SUPPORTED_CURVES.items() raise KeyError('Unknown OID: {!r}'.format(oid))
if info['algo_id'] == algo_id]
return curve_name
class PublicKey(object): class PublicKey(object):
@ -173,7 +171,7 @@ class PublicKey(object):
self.curve_info = SUPPORTED_CURVES[curve_name] self.curve_info = SUPPORTED_CURVES[curve_name]
self.created = int(created) # time since Epoch self.created = int(created) # time since Epoch
self.verifying_key = verifying_key self.verifying_key = verifying_key
self.ecdh = ecdh self.ecdh = bool(ecdh)
if ecdh: if ecdh:
self.algo_id = ECDH_ALGO_ID self.algo_id = ECDH_ALGO_ID
self.ecdh_packet = b'\x03\x01\x08\x07' self.ecdh_packet = b'\x03\x01\x08\x07'

View File

@ -30,11 +30,6 @@ def test_mpi():
assert protocol.mpi(0x123) == b'\x00\x09\x01\x23' assert protocol.mpi(0x123) == b'\x00\x09\x01\x23'
def test_find():
assert protocol.find_curve_by_algo_id(19) == formats.CURVE_NIST256
assert protocol.find_curve_by_algo_id(22) == formats.CURVE_ED25519
def test_armor(): def test_armor():
data = bytearray(range(256)) data = bytearray(range(256))
assert protocol.armor(data, 'TEST') == '''-----BEGIN PGP TEST----- assert protocol.armor(data, 'TEST') == '''-----BEGIN PGP TEST-----