Commit 8bb5363

[SDK-428] Add integration tests for ExportTask (#1300)

2 parents: e2c4aca + bb78630

14 files changed: +847 −62 lines
Lines changed: 3 additions & 1 deletion

@@ -314,6 +314,8 @@ def model_run_with_data_rows(client, configured_project_with_ontology,
                              model_run_predictions, model_run,
                              wait_for_label_processing):
     configured_project_with_ontology.enable_model_assisted_labeling()
+    use_data_row_ids = [p['dataRow']['id'] for p in model_run_predictions]
+    model_run.upsert_data_rows(use_data_row_ids)
 
     upload_task = LabelImport.create_from_objects(
         client, configured_project_with_ontology.uid,
@@ -326,7 +328,7 @@ def model_run_with_data_rows(client, configured_project_with_ontology,
     labels = wait_for_label_processing(configured_project_with_ontology)
     label_ids = [label.uid for label in labels]
     model_run.upsert_labels(label_ids)
-    yield model_run
+    yield model_run, labels
     model_run.delete()
     # TODO: Delete resources when that is possible ..
 
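The fixture now yields a (model_run, labels) tuple rather than a bare model run, so downstream tests unpack both values. A minimal sketch of the updated usage, mirroring test_model_run_export_v2 added later in this commit (the test name here is hypothetical):

# Sketch: unpack the tuple now yielded by the model_run_with_data_rows fixture.
def test_uses_model_run_fixture(model_run_with_data_rows):  # hypothetical test
    model_run, labels = model_run_with_data_rows
    label_ids = [label.uid for label in labels]  # labels created by the fixture
    assert len(label_ids) > 0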
tests/integration/export_v2/test_export_data_rows.py renamed to tests/integration/export/legacy/test_export_data_rows.py

Lines changed: 3 additions & 11 deletions

@@ -1,37 +1,29 @@
 import time
 from labelbox import DataRow
-from labelbox.schema.media_type import MediaType
 
 
 def test_export_data_rows(client, data_row, wait_for_data_row_processing):
     # Ensure created data rows are indexed
     data_row = wait_for_data_row_processing(client, data_row)
     time.sleep(7)  # temp fix for ES indexing delay
-    params = {
-        "include_performance_details": True,
-        "include_labels": True,
-        "media_type_override": MediaType.Image,
-        "project_details": True,
-        "data_row_details": True
-    }
 
     task = DataRow.export_v2(client=client, data_rows=[data_row])
     task.wait_till_done()
     assert task.status == "COMPLETE"
     assert task.errors is None
     assert len(task.result) == 1
-    assert task.result[0]['data_row']['id'] == data_row.uid
+    assert task.result[0]["data_row"]["id"] == data_row.uid
 
     task = DataRow.export_v2(client=client, data_rows=[data_row.uid])
     task.wait_till_done()
     assert task.status == "COMPLETE"
     assert task.errors is None
     assert len(task.result) == 1
-    assert task.result[0]['data_row']['id'] == data_row.uid
+    assert task.result[0]["data_row"]["id"] == data_row.uid
 
     task = DataRow.export_v2(client=client, global_keys=[data_row.global_key])
     task.wait_till_done()
     assert task.status == "COMPLETE"
     assert task.errors is None
     assert len(task.result) == 1
-    assert task.result[0]['data_row']['id'] == data_row.uid
+    assert task.result[0]["data_row"]["id"] == data_row.uid
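For reference, the deleted params block listed the optional export parameters the old test built but never passed. A hedged sketch of passing them explicitly, assuming DataRow.export_v2 accepts a params argument the way the model-run variant below does (client and data_row as in the test fixtures above):

from labelbox import DataRow
from labelbox.schema.media_type import MediaType

# Parameter names taken from the deleted block; passing them is an assumption.
params = {
    "include_performance_details": True,
    "include_labels": True,
    "media_type_override": MediaType.Image,
    "project_details": True,
    "data_row_details": True,
}
task = DataRow.export_v2(client=client, data_rows=[data_row], params=params)
task.wait_till_done()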
File renamed without changes.
Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+import time
+
+
+def _model_run_export_v2_results(model_run, task_name, params, num_retries=5):
+    """Export model run results and retry if no results are returned."""
+    while num_retries > 0:
+        task = model_run.export_v2(task_name, params=params)
+        assert task.name == task_name
+        task.wait_till_done()
+        assert task.status == "COMPLETE"
+        assert task.errors is None
+        task_results = task.result
+        if len(task_results) == 0:
+            num_retries -= 1
+            time.sleep(5)
+        else:
+            return task_results
+    return []
+
+
+def test_model_run_export_v2(model_run_with_data_rows):
+    model_run, labels = model_run_with_data_rows
+    label_ids = [label.uid for label in labels]
+    expected_data_rows = list(model_run.model_run_data_rows())
+
+    task_name = "test_task"
+    params = {"media_attributes": True, "predictions": True}
+    task_results = _model_run_export_v2_results(model_run, task_name, params)
+    assert len(task_results) == len(expected_data_rows)
+
+    for task_result in task_results:
+        # Check export param handling
+        assert 'media_attributes' in task_result and task_result[
+            'media_attributes'] is not None
+        exported_model_run = task_result['experiments'][
+            model_run.model_id]['runs'][model_run.uid]
+        task_label_ids_set = set(
+            map(lambda label: label['id'], exported_model_run['labels']))
+        task_prediction_ids_set = set(
+            map(lambda prediction: prediction['id'],
+                exported_model_run['predictions']))
+        for label_id in task_label_ids_set:
+            assert label_id in label_ids
+        for prediction_id in task_prediction_ids_set:
+            assert prediction_id in label_ids
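The assertions above imply a nested export record. A hypothetical illustration of one task_result, with key names taken from the test and every value a placeholder:

# Hypothetical record shape reconstructed from the assertions above.
task_result = {
    "media_attributes": {"width": 100, "height": 100},  # placeholder values
    "experiments": {
        "<model_id>": {  # keyed by model_run.model_id
            "runs": {
                "<model_run_uid>": {  # keyed by model_run.uid
                    "labels": [{"id": "<label_id>"}],
                    "predictions": [{"id": "<label_id>"}],
                }
            }
        }
    },
}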
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 82 additions & 0 deletions

@@ -0,0 +1,82 @@
+import json
+import time
+
+import pytest
+
+from labelbox import DataRow, ExportTask, StreamType
+
+
+class TestExportDataRow:
+
+    def test_with_data_row_object(self, client, data_row,
+                                  wait_for_data_row_processing):
+        data_row = wait_for_data_row_processing(client, data_row)
+        time.sleep(7)  # temp fix for ES indexing delay
+        export_task = DataRow.export(
+            client=client,
+            data_rows=[data_row],
+            task_name="TestExportDataRow:test_with_data_row_object",
+        )
+        export_task.wait_till_done()
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
+        assert (json.loads(list(export_task.get_stream())[0].json_str)
+                ["data_row"]["id"] == data_row.uid)
+
+    def test_with_id(self, client, data_row, wait_for_data_row_processing):
+        data_row = wait_for_data_row_processing(client, data_row)
+        time.sleep(7)  # temp fix for ES indexing delay
+        export_task = DataRow.export(client=client,
+                                     data_rows=[data_row.uid],
+                                     task_name="TestExportDataRow:test_with_id")
+        export_task.wait_till_done()
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
+        assert (json.loads(list(export_task.get_stream())[0].json_str)
+                ["data_row"]["id"] == data_row.uid)
+
+    def test_with_global_key(self, client, data_row,
+                             wait_for_data_row_processing):
+        data_row = wait_for_data_row_processing(client, data_row)
+        time.sleep(7)  # temp fix for ES indexing delay
+        export_task = DataRow.export(
+            client=client,
+            global_keys=[data_row.global_key],
+            task_name="TestExportDataRow:test_with_global_key",
+        )
+        export_task.wait_till_done()
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(stream_type=StreamType.RESULT) == 1
+        assert (json.loads(list(export_task.get_stream())[0].json_str)
+                ["data_row"]["id"] == data_row.uid)
+
+    def test_with_invalid_id(self, client):
+        export_task = DataRow.export(
+            client=client,
+            data_rows=["invalid_id"],
+            task_name="TestExportDataRow:test_with_invalid_id",
+        )
+        export_task.wait_till_done()
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result() is False
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) is None
+        assert export_task.get_total_lines(
+            stream_type=StreamType.RESULT) is None
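Outside pytest, the same streaming API reads as follows; a minimal sketch using only the calls exercised above (the API key and data-row id are placeholders):

import json

import labelbox as lb
from labelbox import DataRow

client = lb.Client(api_key="YOUR_API_KEY")  # assumed credentials
export_task = DataRow.export(
    client=client,
    data_rows=["<data-row-id>"],  # hypothetical id
    task_name="example:data_row_export",
)
export_task.wait_till_done()
if export_task.has_result():
    # Each stream entry carries one data row as a JSON line.
    for output in export_task.get_stream():
        record = json.loads(output.json_str)
        print(record["data_row"]["id"])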
Lines changed: 78 additions & 0 deletions

@@ -0,0 +1,78 @@
+import json
+
+import pytest
+
+from labelbox import ExportTask, StreamType
+
+
+class TestExportDataset:
+
+    @pytest.mark.parametrize("data_rows", [3], indirect=True)
+    def test_export(self, dataset, data_rows):
+        expected_data_row_ids = [dr.uid for dr in data_rows]
+
+        export_task = dataset.export(task_name="TestExportDataset:test_export")
+        export_task.wait_till_done()
+
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(
+            stream_type=StreamType.RESULT) == len(expected_data_row_ids)
+        data_row_ids = list(
+            map(lambda x: json.loads(x.json_str)["data_row"]["id"],
+                export_task.get_stream()))
+        assert sorted(data_row_ids) == sorted(expected_data_row_ids)
+
+    @pytest.mark.parametrize("data_rows", [3], indirect=True)
+    def test_with_data_row_filter(self, dataset, data_rows):
+        datarow_filter_size = 3
+        expected_data_row_ids = [dr.uid for dr in data_rows
+                                ][:datarow_filter_size]
+        filters = {"data_row_ids": expected_data_row_ids}
+
+        export_task = dataset.export(
+            filters=filters,
+            task_name="TestExportDataset:test_with_data_row_filter")
+        export_task.wait_till_done()
+
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(
+            stream_type=StreamType.RESULT) == datarow_filter_size
+        data_row_ids = list(
+            map(lambda x: json.loads(x.json_str)["data_row"]["id"],
+                export_task.get_stream()))
+        assert sorted(data_row_ids) == sorted(expected_data_row_ids)
+
+    @pytest.mark.parametrize("data_rows", [3], indirect=True)
+    def test_with_global_key_filter(self, dataset, data_rows):
+        datarow_filter_size = 2
+        expected_global_keys = [dr.global_key for dr in data_rows
+                               ][:datarow_filter_size]
+        filters = {"global_keys": expected_global_keys}
+
+        export_task = dataset.export(
+            filters=filters,
+            task_name="TestExportDataset:test_with_global_key_filter")
+        export_task.wait_till_done()
+
+        assert export_task.status == "COMPLETE"
+        assert isinstance(export_task, ExportTask)
+        assert export_task.has_result()
+        assert export_task.has_errors() is False
+        assert export_task.get_total_file_size(
+            stream_type=StreamType.RESULT) > 0
+        assert export_task.get_total_lines(
+            stream_type=StreamType.RESULT) == datarow_filter_size
+        global_keys = list(
+            map(lambda x: json.loads(x.json_str)["data_row"]["global_key"],
+                export_task.get_stream()))
+        assert sorted(global_keys) == sorted(expected_global_keys)
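A similar sketch for dataset-level export with a data-row filter, as the tests above exercise (API key, dataset id, and data-row ids are placeholders):

import json

import labelbox as lb

client = lb.Client(api_key="YOUR_API_KEY")  # assumed credentials
dataset = client.get_dataset("<dataset-id>")  # hypothetical id

# Restrict the export to specific data rows, as in test_with_data_row_filter.
filters = {"data_row_ids": ["<data-row-id-1>", "<data-row-id-2>"]}
export_task = dataset.export(filters=filters,
                             task_name="example:dataset_export")
export_task.wait_till_done()

for output in export_task.get_stream():
    record = json.loads(output.json_str)
    print(record["data_row"]["id"], record["data_row"].get("global_key"))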
