diff --git a/invenio_uploadbyurl/models.py b/invenio_uploadbyurl/models.py index 88c405894b58a5f8d6255f7e7d24956a09f17ec7..f03171e5317f87fca63a89edc0f1ee141d16e2d4 100644 --- a/invenio_uploadbyurl/models.py +++ b/invenio_uploadbyurl/models.py @@ -143,8 +143,8 @@ class SSHKey(db.Model): """Only allow one key per user and remote_server.""" @classmethod - def create(cls, user_id, private_key, username, - remote_server_id=None, valid_until=None, keytype='rsa'): + def create(cls, private_key, username, user, remote_server, + valid_until=None, keytype='rsa'): """Create the SSH key pair.""" # Check if key already exists with db.session.begin_nested(): @@ -154,8 +154,8 @@ class SSHKey(db.Model): valid_until=valid_until, keytype=keytype, username=username, - user_id=user_id, - remote_server_id=remote_server_id, + user=user, + remote_server=remote_server, ) db.session.add(obj) return obj diff --git a/tests/conftest.py b/tests/conftest.py index 2d2b944adfd05a310a48d3023e0de5c6fe21c886..3d15a1fa9fd30d21c75eb20bc221b69143b334c0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -212,10 +212,10 @@ def remote(db, user, user2): server_address=server_name) db.session.add(remote1) - SSHKey.create(user_id=user.id, private_key=key_str_rsa.getvalue(), - username='foo', remote_server_id=remote1.id) - SSHKey.create(user_id=user2[0].id, private_key=key_str_ecdsa.getvalue(), - username='foo', remote_server_id=remote1.id, keytype='ecdsa') + SSHKey.create(private_key=key_str_rsa.getvalue(), + username='foo', user=user, remote_server=remote1) + SSHKey.create(private_key=key_str_ecdsa.getvalue(), username='foo', + user=user2[0], remote_server=remote1, keytype='ecdsa') yield remote1 diff --git a/tests/test_models.py b/tests/test_models.py index b69aebb84a3c35ca71b1f9b201ca93c5aeed8227..7959f939569c0aed39b5b546f79de4c82bd0b671 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -90,24 +90,21 @@ def test_ssh_key_pair(app, db, user, sshkeys): assert RemoteServer.get_by_name('test1').name == 'test1' # Create ssh key pair - remote_server_id1 = RemoteServer.get_by_name('test1').id - remote_server_id2 = RemoteServer.get_by_name('test2').id - keypair = SSHKey.create(user.id, sshkeys[0], - 'username', remote_server_id1) + remote_server1 = RemoteServer.get_by_name('test1') + remote_server2 = RemoteServer.get_by_name('test2') + keypair = SSHKey.create(sshkeys[0], 'username', user, remote_server1) assert keypair.user.id == user.id - keypair = SSHKey.create(user.id, sshkeys[0], - 'username', remote_server_id2) - assert keypair.remote_server.id == remote_server_id2 + keypair = SSHKey.create(sshkeys[0], 'username', user, remote_server2) + assert keypair.remote_server.id == remote_server2.id # test unique constraint with pytest.raises(IntegrityError): - keypair = SSHKey.create(user.id, sshkeys[0], - 'username', remote_server_id1) + keypair = SSHKey.create(sshkeys[0], 'username', user, remote_server1) # test key retrieval - key = SSHKey.get(user.id, remote_server_id1) + key = SSHKey.get(user.id, remote_server1.id) assert key.user.id == user.id - assert key.remote_server.id == remote_server_id1 + assert key.remote_server.id == remote_server1.id assert key.private_key == sshkeys[0] # test key update