-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy path test_base.py
204 lines (146 loc) · 7.06 KB
/
test_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# -*- coding: utf-8 -*-
"""Tests for the `BaseSubmissionController`."""
import pytest
from aiida import orm
from aiida_submission_controller.base import BaseSubmissionController, get_extras_dict
# Register AiiDA's pytest fixtures plugin; the `aiida_profile_clean` fixture used by the
# tests below is presumably provided by it (no other definition of it is visible here).
pytest_plugins = ["aiida.manage.tests.pytest_fixtures"]
class ExampleSubmissionController(BaseSubmissionController):
    """Minimal concrete submission controller, used only to exercise the base-class logic."""

    n_submissions: int
    "Number of submissions to be made."

    def get_all_extras_to_submit(self):
        """Return the set of extras tuples ``("0",), ("1",), ...``, one per requested submission."""
        return {(str(index),) for index in range(self.n_submissions)}

    def get_inputs_and_processclass_from_extras(self, extras_values):
        """Not exercised by these tests: intentionally does nothing and returns ``None``."""
        return None
@pytest.fixture
def get_group():
    """Provide a factory that stores a group populated with `WorkChainNode`s.

    Each node gets the extra ``idx`` set to ``str(i)``; the nodes with the first
    `n_sealed` indices are sealed, the remaining ones are left unsealed.
    """

    def _get_group(label="test_group", n_nodes=0, n_sealed=0):
        new_group = orm.Group(label=label).store()
        if n_nodes == 0:
            return new_group

        idx_values = [str(number) for number in range(n_nodes)]
        # The unsealed processes are created first, followed by the sealed ones.
        batches = ((idx_values[n_sealed:], False), (idx_values[:n_sealed], True))
        for values, sealed in batches:
            new_group.add_nodes(_generate_processes("idx", values, seal=sealed))
        return new_group

    return _get_group
def _generate_processes(extra_key, extra_values, seal=False, group=None):
    """Generate and store one `WorkChainNode` per value in `extra_values`.

    :param extra_key: key of the extra set on every generated node.
    :param extra_values: values to assign; the i-th node gets ``extra_key = extra_values[i]``.
    :param seal: if True, seal each node after setting its extra (sealed nodes are the ones the
        tests treat as no longer active).
    :param group: optional stored group to which each generated node is added.
    :return: list of the generated nodes, in the order of `extra_values`.
    """
    processes = []
    for extra in extra_values:
        process = orm.WorkChainNode().store()
        # `Node.set_extra` is deprecated since aiida-core 2.0 in favour of the `base.extras` namespace.
        process.base.extras.set(extra_key, extra)
        if seal:
            process.seal()
        # Explicit `None` check, so we never depend on the truthiness of an ORM `Group` object.
        if group is not None:
            group.add_nodes(process)
        processes.append(process)
    return processes
def _seal_processes(extra_key, extra_values):
    """Seal every stored `WorkChainNode` whose extra `extra_key` is one of `extra_values`."""
    extra_filter = {f"extras.{extra_key}": {"in": extra_values}}
    query = orm.QueryBuilder()
    query.append(orm.WorkChainNode, filters=extra_filter)
    for node in query.all(flat=True):
        node.seal()
def test_get_extras_dict():
    """Check that `get_extras_dict` nests dot-separated keys into a tree of dictionaries."""
    keys = ["a.1.1", "a.1.2", "b.1", "c.d"]
    values = [1, 2, 3, 4]
    expected = {
        "a": {"1": {"1": 1, "2": 2}},
        "b": {"1": 3},
        "c": {"d": 4},
    }
    assert get_extras_dict(keys, values) == expected
@pytest.mark.usefixtures("aiida_profile_clean")
def test_get_query(get_group):
    """Test the `get_query` method."""
    # 5 processes in the group: the first 3 are sealed, the last 2 are still active.
    _ = get_group(n_nodes=5, n_sealed=3)
    submission_controller = ExampleSubmissionController(
        group_label="test_group", max_concurrent=1, unique_extra_keys=("idx",), n_submissions=5
    )

    # Check `get_process_extra_projections`
    process_projections = submission_controller.get_process_extra_projections()
    assert process_projections == ["extras.idx"]

    # Without `only_active` all processes are matched; with it, only the unsealed ones.
    qb_all = submission_controller.get_query(process_projections=process_projections)
    qb_active = submission_controller.get_query(process_projections=process_projections, only_active=True)
    assert qb_all.count() == 5
    assert qb_active.count() == 2
@pytest.mark.usefixtures("aiida_profile_clean")
def test_get_all_submitted_pks(get_group):
    """Test the `get_all_submitted_pks` method (and `_check_submitted_extras`)."""
    group = get_group(n_nodes=5)
    submission_controller = ExampleSubmissionController(
        group_label="test_group", max_concurrent=1, unique_extra_keys=("idx",), n_submissions=5
    )

    submitted_pks = submission_controller.get_all_submitted_pks()
    assert len(submitted_pks) == 5

    # Test `_check_submitted_extras` as well
    submitted_extras = submission_controller._check_submitted_extras()
    assert submitted_extras == {("0",), ("1",), ("2",), ("3",), ("4",)}

    # Check that a node without the required extra is ignored
    test_node = orm.WorkChainNode().store()
    group.add_nodes(test_node)
    submitted_pks = submission_controller.get_all_submitted_pks()
    assert len(submitted_pks) == 5

    # Once the required extra is set, the node is counted
    # (`Node.set_extra` is deprecated since aiida-core 2.0 in favour of `base.extras`).
    test_node.base.extras.set("idx", 6)
    submitted_pks = submission_controller.get_all_submitted_pks()
    assert len(submitted_pks) == 6
@pytest.mark.usefixtures("aiida_profile_clean")
def test_get_all_submitted_processes(get_group):
    """Test the `get_all_submitted_processes` method (and `_count_active_in_group`)."""
    # 5 processes: "0"-"2" are sealed, "3" and "4" are still active.
    _ = get_group(n_nodes=5, n_sealed=3)
    submission_controller = ExampleSubmissionController(
        group_label="test_group", max_concurrent=1, unique_extra_keys=("idx",), n_submissions=5
    )

    # All processes, keyed by their extras tuple
    submitted_processes_all = submission_controller.get_all_submitted_processes()
    assert len(submitted_processes_all) == 5
    assert set(submitted_processes_all.keys()) == {("0",), ("1",), ("2",), ("3",), ("4",)}

    # Only active processes
    submitted_processes_active = submission_controller.get_all_submitted_processes(only_active=True)
    assert len(submitted_processes_active) == 2
    assert set(submitted_processes_active.keys()) == {("3",), ("4",)}

    # Use this setup to also test `_count_active_in_group`
    assert submission_controller._count_active_in_group() == 2
@pytest.mark.usefixtures("aiida_profile_clean")
def test_submit_new_batch(get_group):
    """Test the `submit_new_batch` method.

    A controller with ``max_concurrent=2`` and ten extras to submit is driven through
    successive dry-run batches. Finished processes are simulated by sealing them, until
    every extra has been submitted.
    """
    group = get_group(n_nodes=5, n_sealed=3)
    submission_controller = ExampleSubmissionController(
        group_label="test_group", max_concurrent=2, unique_extra_keys=("idx",), n_submissions=10
    )

    # Initial state: 2 active processes ("3", "4"), 3 sealed processes ("0"-"2") and 5 still to run
    assert submission_controller.num_active_slots == 2
    assert submission_controller.num_available_slots == 0
    assert submission_controller.num_to_run == 5
    assert submission_controller.num_already_run == 5

    # No available slots, so nothing is submitted
    submitted = submission_controller.submit_new_batch(dry_run=True, sort=True)
    assert len(submitted) == 0

    # Seal one of the active processes, which frees up a single slot
    _seal_processes("idx", ["3"])
    assert submission_controller.num_active_slots == 1
    assert submission_controller.num_available_slots == 1
    submitted = submission_controller.submit_new_batch(dry_run=True, sort=True)
    assert set(submitted.keys()) == {("5",)}

    # The dry run created no process, so mimic the submission of "5" with a sealed dummy
    # process, and seal the remaining active process to free up both slots
    _seal_processes("idx", ["4"])
    _generate_processes("idx", ["5"], seal=True, group=group)

    # Submit new batches (in sorted order of the extras) until all are submitted
    submitted = submission_controller.submit_new_batch(dry_run=True, sort=True)
    assert set(submitted.keys()) == {("6",), ("7",)}
    _generate_processes("idx", ["6", "7"], seal=True, group=group)

    submitted = submission_controller.submit_new_batch(dry_run=True, sort=True)
    assert set(submitted.keys()) == {("8",), ("9",)}
    _generate_processes("idx", ["8", "9"], seal=True, group=group)

    # Final state: every process has been submitted and all slots are free
    submitted = submission_controller.submit_new_batch(dry_run=True, sort=True)
    assert len(submitted) == 0
    assert submission_controller.num_active_slots == 0
    assert submission_controller.num_available_slots == 2
    assert submission_controller.num_to_run == 0
    assert submission_controller.num_already_run == 10