Skip to content

Commit fbb05e8

Browse files
authored
Add type annotations (#381)
Add type annotations to the pywinrm library. This can help reduce type related bugs inside the library and also help callers to figure out the correct values that can be used in the public API.
1 parent 6f2ef68 commit fbb05e8

File tree

10 files changed

+207
-130
lines changed

10 files changed

+207
-130
lines changed

.github/workflows/ci.yml

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ jobs:
7979
run: |
8080
python -m black . --check
8181
python -m isort . --check-only
82+
python -m mypy .
8283
pytest -v --cov=winrm --cov-report=term-missing winrm/tests/
8384
8485
- name: upload coverage data

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,4 @@ __pycache__
2929
/winrm/tests/config.json
3030
.pytest_cache
3131
venv
32+
.mypy_cache

pyproject.toml

+36-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ include-package-data = true
5959
packages = ["winrm"]
6060

6161
[tool.setuptools.package-data]
62+
"winrm" = ["py.typed"]
6263
"winrm.tests" = ["*.ps1"]
6364

6465
[tool.setuptools.dynamic]
@@ -77,4 +78,38 @@ exclude = '''
7778
'''
7879

7980
[tool.isort]
80-
profile = "black"
81+
profile = "black"
82+
83+
[tool.mypy]
84+
exclude = "build/|winrm/tests/|winrm/vendor/"
85+
mypy_path = "$MYPY_CONFIG_FILE_DIR"
86+
python_version = "3.8"
87+
show_error_codes = true
88+
show_column_numbers = true
89+
disallow_any_unimported = true
90+
disallow_untyped_calls = true
91+
disallow_untyped_defs = true
92+
disallow_incomplete_defs = true
93+
check_untyped_defs = true
94+
disallow_untyped_decorators = true
95+
no_implicit_reexport = true
96+
warn_redundant_casts = true
97+
warn_unused_ignores = true
98+
warn_no_return = true
99+
warn_unreachable = true
100+
101+
[[tool.mypy.overrides]]
102+
module = "winrm.vendor.*"
103+
follow_imports = "skip"
104+
105+
[[tool.mypy.overrides]]
106+
module = "requests.packages.urllib3.*"
107+
ignore_missing_imports = true
108+
109+
[[tool.mypy.overrides]]
110+
module = "requests_credssp"
111+
ignore_missing_imports = true
112+
113+
[[tool.mypy.overrides]]
114+
module = "requests_ntlm"
115+
ignore_missing_imports = true

requirements-test.txt

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# this assumes the base requirements have been satisfied via setup.py
22
black == 24.4.2
33
isort == 5.13.2
4+
mypy == 1.10.0
45
pytest
56
pytest-cov
67
mock
8+
types-requests
9+
types-xmltodict

winrm/__init__.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import collections.abc
34
import re
5+
import typing as t
46
import warnings
57
import xml.etree.ElementTree as ET
68
from base64 import b64encode
@@ -22,22 +24,22 @@
2224
class Response(object):
2325
"""Response from a remote command execution"""
2426

25-
def __init__(self, args):
27+
def __init__(self, args: tuple[bytes, bytes, int]) -> None:
2628
self.std_out, self.std_err, self.status_code = args
2729

28-
def __repr__(self):
30+
def __repr__(self) -> str:
2931
# TODO put tree dots at the end if out/err was truncated
30-
return '<Response code {0}, out "{1}", err "{2}">'.format(self.status_code, self.std_out[:20], self.std_err[:20])
32+
return '<Response code {0}, out "{1!r}", err "{2!r}">'.format(self.status_code, self.std_out[:20], self.std_err[:20])
3133

3234

3335
class Session(object):
3436
# TODO implement context manager methods
35-
def __init__(self, target, auth, **kwargs):
37+
def __init__(self, target: str, auth: tuple[str, str], **kwargs: t.Any) -> None:
3638
username, password = auth
3739
self.url = self._build_url(target, kwargs.get("transport", "plaintext"))
3840
self.protocol = Protocol(self.url, username=username, password=password, **kwargs)
3941

40-
def run_cmd(self, command, args=()):
42+
def run_cmd(self, command: str, args: collections.abc.Iterable[str | bytes] = ()) -> Response:
4143
# TODO optimize perf. Do not call open/close shell every time
4244
shell_id = self.protocol.open_shell()
4345
command_id = self.protocol.run_command(shell_id, command, args)
@@ -46,7 +48,7 @@ def run_cmd(self, command, args=()):
4648
self.protocol.close_shell(shell_id)
4749
return rs
4850

49-
def run_ps(self, script):
51+
def run_ps(self, script: str) -> Response:
5052
"""base64 encodes a Powershell script and executes the powershell
5153
encoded script command
5254
"""
@@ -59,7 +61,7 @@ def run_ps(self, script):
5961
rs.std_err = self._clean_error_msg(rs.std_err)
6062
return rs
6163

62-
def _clean_error_msg(self, msg):
64+
def _clean_error_msg(self, msg: bytes) -> bytes:
6365
"""converts a Powershell CLIXML message to a more human readable string"""
6466
# TODO prepare unit test, beautify code
6567
# if the msg does not start with this, return it as is
@@ -77,7 +79,8 @@ def _clean_error_msg(self, msg):
7779
for s in nodes:
7880
# append error msg string to result, also
7981
# the hex chars represent CRLF so we replace with newline
80-
new_msg += s.text.replace("_x000D__x000A_", "\n")
82+
if s.text:
83+
new_msg += s.text.replace("_x000D__x000A_", "\n")
8184
except Exception as e:
8285
# if any of the above fails, the msg was not true xml
8386
# print a warning and return the original string
@@ -93,7 +96,7 @@ def _clean_error_msg(self, msg):
9396
# just return the original message
9497
return msg
9598

96-
def _strip_namespace(self, xml):
99+
def _strip_namespace(self, xml: bytes) -> bytes:
97100
"""strips any namespaces from an xml string"""
98101
p = re.compile(b'xmlns=*[""][^""]*[""]')
99102
allmatches = p.finditer(xml)
@@ -102,8 +105,11 @@ def _strip_namespace(self, xml):
102105
return xml
103106

104107
@staticmethod
105-
def _build_url(target, transport):
108+
def _build_url(target: str, transport: str) -> str:
106109
match = re.match(r"(?i)^((?P<scheme>http[s]?)://)?(?P<host>[0-9a-z-_.]+)(:(?P<port>\d+))?(?P<path>(/)?(wsman)?)?", target) # NOQA
110+
if not match:
111+
raise ValueError("Invalid target URL: {0}".format(target))
112+
107113
scheme = match.group("scheme")
108114
if not scheme:
109115
# TODO do we have anything other than HTTP/HTTPS

winrm/encryption.py

+23-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import re
24
import struct
35
from urllib.parse import urlsplit
@@ -12,7 +14,7 @@ class Encryption(object):
1214
SIXTEN_KB = 16384
1315
MIME_BOUNDARY = b"--Encrypted Boundary"
1416

15-
def __init__(self, session, protocol):
17+
def __init__(self, session: requests.Session, protocol: str) -> None:
1618
"""
1719
[MS-WSMV] v30.0 2016-07-14
1820
@@ -51,7 +53,7 @@ def __init__(self, session, protocol):
5153
else:
5254
raise WinRMError("Encryption for protocol '%s' not supported in pywinrm" % protocol)
5355

54-
def prepare_encrypted_request(self, session, endpoint, message):
56+
def prepare_encrypted_request(self, session: requests.Session, endpoint: str | bytes, message: bytes) -> requests.PreparedRequest:
5557
"""
5658
Creates a prepared request to send to the server with an encrypted message
5759
and correct headers
@@ -77,28 +79,29 @@ def prepare_encrypted_request(self, session, endpoint, message):
7779

7880
request = requests.Request("POST", endpoint, data=encrypted_message)
7981
prepared_request = session.prepare_request(request)
80-
prepared_request.headers["Content-Length"] = str(len(prepared_request.body))
82+
prepared_request.headers["Content-Length"] = str(len(prepared_request.body)) if prepared_request.body else "0"
8183
prepared_request.headers["Content-Type"] = '{0};protocol="{1}";boundary="Encrypted Boundary"'.format(content_type, self.protocol_string.decode())
8284

8385
return prepared_request
8486

85-
def parse_encrypted_response(self, response):
87+
def parse_encrypted_response(self, response: requests.Response) -> bytes:
8688
"""
8789
Takes in the encrypted response from the server and decrypts it
8890
8991
:param response: The response that needs to be decrypted
9092
:return: The unencrypted message from the server
9193
"""
9294
content_type = response.headers["Content-Type"]
95+
9396
if 'protocol="{0}"'.format(self.protocol_string.decode()) in content_type:
9497
host = urlsplit(response.request.url).hostname
9598
msg = self._decrypt_response(response, host)
9699
else:
97-
msg = response.text
100+
msg = response.content
98101

99102
return msg
100103

101-
def _encrypt_message(self, message, host):
104+
def _encrypt_message(self, message: bytes, host: str | bytes | None) -> bytes:
102105
message_length = str(len(message)).encode()
103106
encrypted_stream = self._build_message(message, host)
104107

@@ -111,7 +114,7 @@ def _encrypt_message(self, message, host):
111114

112115
return message_payload
113116

114-
def _decrypt_response(self, response, host):
117+
def _decrypt_response(self, response: requests.Response, host: str | bytes | None) -> bytes:
115118
parts = response.content.split(self.MIME_BOUNDARY + b"\r\n")
116119
parts = list(filter(None, parts)) # filter out empty parts of the split
117120
message = b""
@@ -139,55 +142,55 @@ def _decrypt_response(self, response, host):
139142

140143
return message
141144

142-
def _decrypt_ntlm_message(self, encrypted_data, host):
145+
def _decrypt_ntlm_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
143146
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
144147
signature = encrypted_data[4 : signature_length + 4]
145148
encrypted_message = encrypted_data[signature_length + 4 :]
146149

147-
message = self.session.auth.session_security.unwrap(encrypted_message, signature)
150+
message = self.session.auth.session_security.unwrap(encrypted_message, signature) # type: ignore[union-attr]
148151

149152
return message
150153

151-
def _decrypt_credssp_message(self, encrypted_data, host):
154+
def _decrypt_credssp_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
152155
# trailer_length = struct.unpack("<i", encrypted_data[:4])[0]
153156
encrypted_message = encrypted_data[4:]
154157

155-
credssp_context = self.session.auth.contexts[host]
158+
credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr]
156159
message = credssp_context.unwrap(encrypted_message)
157160

158161
return message
159162

160-
def _decrypt_kerberos_message(self, encrypted_data, host):
163+
def _decrypt_kerberos_message(self, encrypted_data: bytes, host: str | bytes | None) -> bytes:
161164
signature_length = struct.unpack("<i", encrypted_data[:4])[0]
162165
signature = encrypted_data[4 : signature_length + 4]
163166
encrypted_message = encrypted_data[signature_length + 4 :]
164167

165-
message = self.session.auth.unwrap_winrm(host, encrypted_message, signature)
168+
message = self.session.auth.unwrap_winrm(host, encrypted_message, signature) # type: ignore[union-attr]
166169

167170
return message
168171

169-
def _build_ntlm_message(self, message, host):
170-
sealed_message, signature = self.session.auth.session_security.wrap(message)
172+
def _build_ntlm_message(self, message: bytes, host: str | bytes | None) -> bytes:
173+
sealed_message, signature = self.session.auth.session_security.wrap(message) # type: ignore[union-attr]
171174
signature_length = struct.pack("<i", len(signature))
172175

173176
return signature_length + signature + sealed_message
174177

175-
def _build_credssp_message(self, message, host):
176-
credssp_context = self.session.auth.contexts[host]
178+
def _build_credssp_message(self, message: bytes, host: str | bytes | None) -> bytes:
179+
credssp_context = self.session.auth.contexts[host] # type: ignore[union-attr]
177180
sealed_message = credssp_context.wrap(message)
178181

179182
cipher_negotiated = credssp_context.tls_connection.get_cipher_name()
180183
trailer_length = self._get_credssp_trailer_length(len(message), cipher_negotiated)
181184

182185
return struct.pack("<i", trailer_length) + sealed_message
183186

184-
def _build_kerberos_message(self, message, host):
185-
sealed_message, signature = self.session.auth.wrap_winrm(host, message)
187+
def _build_kerberos_message(self, message: bytes, host: str | bytes | None) -> bytes:
188+
sealed_message, signature = self.session.auth.wrap_winrm(host, message) # type: ignore[union-attr]
186189
signature_length = struct.pack("<i", len(signature))
187190

188191
return signature_length + signature + sealed_message
189192

190-
def _get_credssp_trailer_length(self, message_length, cipher_suite):
193+
def _get_credssp_trailer_length(self, message_length: int, cipher_suite: str) -> int:
191194
# I really don't like the way this works but can't find a better way, MS
192195
# allows you to get this info through the struct SecPkgContext_StreamSizes
193196
# but there is no GSSAPI/OpenSSL equivalent so we need to calculate it

winrm/exceptions.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,22 @@ class WinRMTransportError(Exception):
1111
"""WinRM errors specific to transport-level problems (unexpected HTTP error codes, etc)"""
1212

1313
@property
14-
def protocol(self):
14+
def protocol(self) -> str:
1515
return self.args[0]
1616

1717
@property
18-
def code(self):
18+
def code(self) -> int:
1919
return self.args[1]
2020

2121
@property
22-
def message(self):
22+
def message(self) -> str:
2323
return "Bad HTTP response returned from server. Code {0}".format(self.code)
2424

2525
@property
26-
def response_text(self):
26+
def response_text(self) -> str:
2727
return self.args[2]
2828

29-
def __str__(self):
29+
def __str__(self) -> str:
3030
return self.message
3131

3232

0 commit comments

Comments
 (0)