Skip to content

Commit 47bab9d

Browse files
author
Gareth
authored
Merge pull request #328 from Labelbox/gj/batch-mode
[DIAG-944] Batch Mode
2 parents e4bac5a + e0ed0b6 commit 47bab9d

File tree

5 files changed

+182
-17
lines changed

5 files changed

+182
-17
lines changed

labelbox/schema/project.py

Lines changed: 130 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,31 @@
1+
import enum
12
import json
2-
import time
33
import logging
4+
import time
5+
import warnings
46
from collections import namedtuple
57
from datetime import datetime, timezone
68
from pathlib import Path
7-
from typing import Dict, Union, Iterable
9+
from typing import Dict, Union, Iterable, List, Optional
810
from urllib.parse import urlparse
9-
import requests
11+
1012
import ndjson
13+
import requests
1114

1215
from labelbox import utils
13-
from labelbox.schema.data_row import DataRow
14-
from labelbox.orm import query
15-
from labelbox.schema.bulk_import_request import BulkImportRequest
1616
from labelbox.exceptions import InvalidQueryError, LabelboxError
17+
from labelbox.orm import query
1718
from labelbox.orm.db_object import DbObject, Updateable, Deletable
1819
from labelbox.orm.model import Entity, Field, Relationship
1920
from labelbox.pagination import PaginatedCollection
21+
from labelbox.schema.bulk_import_request import BulkImportRequest
22+
from labelbox.schema.data_row import DataRow
2023

2124
try:
2225
datetime.fromisoformat # type: ignore[attr-defined]
2326
except AttributeError:
2427
from backports.datetime_fromisoformat import MonkeyPatch
28+
2529
MonkeyPatch.patch_fromisoformat()
2630

2731
try:
@@ -31,6 +35,19 @@
3135

3236
logger = logging.getLogger(__name__)
3337

38+
MAX_QUEUE_BATCH_SIZE = 1000
39+
40+
41+
class QueueMode(enum.Enum):
    """Queueing modes a Project can operate in.

    ``Batch``   -- data rows are queued explicitly in batches via
                   ``Project.queue`` / ``Project.dequeue``.
    ``Dataset`` -- the legacy mode where attached datasets feed the queue.
    """

    Batch = "Batch"
    Dataset = "Dataset"
44+
45+
46+
class QueueErrors(enum.Enum):
    """Per-data-row error codes the batch queue mutations can return."""

    InvalidDataRowType = 'InvalidDataRowType'
    AlreadyInProject = 'AlreadyInProject'
    HasAttachedLabel = 'HasAttachedLabel'
50+
3451

3552
class Project(DbObject, Updateable, Deletable):
3653
""" A Project is a container that includes a labeling frontend, an ontology,
@@ -79,6 +96,14 @@ class Project(DbObject, Updateable, Deletable):
7996
benchmarks = Relationship.ToMany("Benchmark", False)
8097
ontology = Relationship.ToOne("Ontology", True)
8198

99+
def update(self, **kwargs):
    """Update this project's fields.

    Intercepts the ``queue_mode`` keyword (a :class:`QueueMode`) and applies
    it through the dedicated mutation before delegating every remaining
    field to the generic DbObject update.
    """
    queue_mode = kwargs.pop("queue_mode", None)
    if queue_mode:
        self._update_queue_mode(queue_mode)
    return super().update(**kwargs)
106+
82107
def members(self):
83108
""" Fetch all current members for this project
84109
@@ -407,14 +432,14 @@ def setup(self, labeling_frontend, labeling_frontend_options):
407432
a.k.a. project ontology. If given a `dict` it will be converted
408433
to `str` using `json.dumps`.
409434
"""
410-
organization = self.client.get_organization()
435+
411436
if not isinstance(labeling_frontend_options, str):
412437
labeling_frontend_options = json.dumps(labeling_frontend_options)
413438

414439
self.labeling_frontend.connect(labeling_frontend)
415440

416441
LFO = Entity.LabelingFrontendOptions
417-
labeling_frontend_options = self.client._create(
442+
self.client._create(
418443
LFO, {
419444
LFO.project: self,
420445
LFO.labeling_frontend: labeling_frontend,
@@ -424,6 +449,103 @@ def setup(self, labeling_frontend, labeling_frontend_options):
424449
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
425450
self.update(setup_complete=timestamp)
426451

452+
def queue(self, data_row_ids: List[str]):
    """Add Data Rows to the Project queue.

    Args:
        data_row_ids: ids of the data rows to enqueue (max batch size applies).
    Returns:
        The list of per-row results from the batch mutation.
    """
    return self._post_batch("submitBatchOfDataRows", data_row_ids)
457+
458+
def dequeue(self, data_row_ids: List[str]):
    """Remove Data Rows from the Project queue.

    Args:
        data_row_ids: ids of the data rows to remove (max batch size applies).
    Returns:
        The list of per-row results from the batch mutation.
    """
    return self._post_batch("removeBatchOfDataRows", data_row_ids)
463+
464+
def _post_batch(self, method, data_row_ids: List[str]):
    """Run a batch queue mutation against this project.

    Args:
        method (str): name of the GraphQL mutation field to invoke
            (e.g. ``submitBatchOfDataRows`` or ``removeBatchOfDataRows``).
        data_row_ids (List[str]): data row ids to send; must not exceed
            ``MAX_QUEUE_BATCH_SIZE``.
    Returns:
        The mutation's ``dataRows`` payload — presumably only the rows that
        FAILED, each as ``{dataRowId, error}`` (empty on full success; the
        integration test asserts emptiness — TODO confirm against the API).
    Raises:
        ValueError: if the project is not in Batch mode, the batch is too
            large, or every submitted row failed.
    """
    if self.queue_mode() != QueueMode.Batch:
        raise ValueError("Project must be in batch mode")

    if len(data_row_ids) > MAX_QUEUE_BATCH_SIZE:
        raise ValueError(
            f"Batch exceeds max size of {MAX_QUEUE_BATCH_SIZE}, consider breaking it into parts"
        )

    # Named `query_str` so we don't shadow the `labelbox.orm.query` module
    # imported at file level.
    query_str = """mutation %sPyApi($projectId: ID!, $dataRowIds: [ID!]!) {
      project(where: {id: $projectId}) {
        %s(data: {dataRowIds: $dataRowIds}) {
          dataRows {
            dataRowId
            error
          }
        }
      }
    }
    """ % (method, method)

    res = self.client.execute(query_str, {
        "projectId": self.uid,
        "dataRowIds": data_row_ids
    })["project"][method]["dataRows"]

    # TODO: figure out error messaging
    # All rows failed -> hard error. Guard on a non-empty input so an empty
    # batch (0 == 0) does not spuriously raise.
    if data_row_ids and len(data_row_ids) == len(res):
        raise ValueError("No dataRows were submitted successfully")

    # Only warn when some rows actually came back with errors. (Previously
    # this tested `len(data_row_ids) > 0`, which is true for every non-empty
    # batch and warned even on full success.)
    if len(res) > 0:
        warnings.warn("Some Data Rows were not submitted successfully")

    return res
500+
501+
def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
    """Switch this project's queueing mode on the server.

    No-op (no network call) if the project is already in ``mode``.

    Args:
        mode: the desired :class:`QueueMode`.
    Returns:
        The mode that is now in effect.
    Raises:
        ValueError: if ``mode`` is not a recognized QueueMode member.
    """
    if self.queue_mode() == mode:
        return mode

    if mode == QueueMode.Batch:
        status = "ENABLED"
    elif mode == QueueMode.Dataset:
        status = "DISABLED"
    else:
        # Fixed message: the accepted values are the QueueMode members,
        # not bare `BATCH`/`DATASET` strings as the old message implied.
        raise ValueError(
            "Must provide either `QueueMode.Batch` or `QueueMode.Dataset` as a mode"
        )

    query_str = """mutation %s($projectId: ID!, $status: TagSetStatusInput!) {
      project(where: {id: $projectId}) {
        setTagSetStatus(input: {tagSetStatus: $status}) {
          tagSetStatus
        }
      }
    }
    """ % "setTagSetStatusPyApi"

    self.client.execute(query_str, {
        'projectId': self.uid,
        'status': status
    })

    return mode
529+
530+
def queue_mode(self) -> QueueMode:
    """Fetch this project's current queueing mode from the server.

    Returns:
        ``QueueMode.Batch`` when the tag-set status is ``ENABLED``,
        ``QueueMode.Dataset`` when it is ``DISABLED``.
    Raises:
        ValueError: if the server reports an unrecognized status.
    """
    query_str = """query %s($projectId: ID!) {
      project(where: {id: $projectId}) {
        tagSetStatus
      }
    }
    """ % "GetTagSetStatusPyApi"

    status = self.client.execute(
        query_str, {'projectId': self.uid})["project"]["tagSetStatus"]

    if status == "ENABLED":
        return QueueMode.Batch
    elif status == "DISABLED":
        return QueueMode.Dataset
    else:
        # Include the unexpected value so the failure is diagnosable
        # (the old message dropped it).
        raise ValueError(f"Status not known: {status}")
548+
427549
def validate_labeling_parameter_overrides(self, data):
428550
for idx, row in enumerate(data):
429551
if len(row) != 3:

tests/data/annotation_types/data/test_text.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import pytest
24
from pydantic import ValidationError
35

@@ -22,11 +24,13 @@ def test_url():
2224
assert len(text) == 3541
2325

2426

25-
def test_file():
26-
file_path = "tests/data/assets/sample_text.txt"
27-
text_data = TextData(file_path=file_path)
28-
text = text_data.value
29-
assert len(text) == 3541
27+
def test_file(tmpdir):
    """TextData should read file contents from a file_path."""
    expected = "foo bar baz"
    filename = "hello.txt"
    data_dir = tmpdir.mkdir('data')
    data_dir.join(filename).write(expected)
    text_data = TextData(file_path=os.path.join(data_dir.strpath, filename))
    assert len(text_data.value) == len(expected)
3034

3135

3236
def test_ref():

tests/data/annotation_types/geometry/test_rectangle.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from pydantic import ValidationError
2-
import pytest
31
import cv2
2+
import pytest
3+
from pydantic import ValidationError
44

55
from labelbox.data.annotation_types import Point, Rectangle
66

@@ -18,3 +18,7 @@ def test_rectangle():
1818

1919
raster = rectangle.draw(height=32, width=32)
2020
assert (cv2.imread("tests/data/assets/rectangle.png") == raster).all()
21+
22+
xyhw = Rectangle.from_xyhw(0., 0, 10, 10)
23+
assert xyhw.start == Point(x=0, y=0.)
24+
assert xyhw.end == Point(x=10, y=10.0)

tests/integration/test_batch.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import pytest
2+
3+
from labelbox import Dataset, Project
4+
from labelbox.schema.project import QueueMode
5+
6+
IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg"
7+
8+
9+
@pytest.fixture
def big_dataset(dataset: Dataset):
    """Yield `dataset` pre-populated with 250 copies of the same image row,
    deleting the dataset on teardown."""
    row = {
        "row_data": IMAGE_URL,
        "external_id": "my-image"
    }
    upload_task = dataset.create_data_rows([row] * 250)
    upload_task.wait_till_done()
    yield dataset
    dataset.delete()
21+
22+
23+
def test_submit_batch(configured_project: Project, big_dataset):
    """Queueing and then dequeueing a batch should report no per-row errors."""
    configured_project.update(queue_mode=QueueMode.Batch)

    row_ids = [row.uid for row in big_dataset.export_data_rows()]
    assert not len(configured_project.queue(row_ids))
    assert not len(configured_project.dequeue(row_ids))

tests/integration/test_project.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import json
22

3-
import requests
4-
import ndjson
53
import pytest
64

75
from labelbox import Project, LabelingFrontend
86
from labelbox.exceptions import InvalidQueryError
7+
from labelbox.schema.project import QueueMode
98

109

1110
def test_project(client, rand_gen):
@@ -107,3 +106,9 @@ def test_attach_instructions(client, project):
107106
def test_queued_data_row_export(configured_project):
108107
result = configured_project.export_queued_data_rows()
109108
assert len(result) == 1
109+
110+
111+
def test_queue_mode(configured_project: Project):
    """A project starts in Dataset mode and switches to Batch via update()."""
    assert configured_project.queue_mode() is QueueMode.Dataset
    configured_project.update(queue_mode=QueueMode.Batch)
    assert configured_project.queue_mode() is QueueMode.Batch

0 commit comments

Comments
 (0)