Skip to content

Commit e623012

Browse files
Merge pull request #565 from Labelbox/kkim/AL-2218
[AL-2218] Adding optional metadata_fields to create_data_row()
2 parents: 019ec7d + 69400d0 · commit e623012

File tree

11 files changed: 186 additions (+186) and 23 deletions (-23)

labelbox/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
logger = logging.getLogger(__name__)
3838

3939
_LABELBOX_API_KEY = "LABELBOX_API_KEY"
40+
_DATAROW_METADATA_CREATE_ERROR = "Failed to validate the metadata"
4041

4142

4243
class Client:
@@ -90,6 +91,7 @@ def __init__(self,
9091
'Authorization': 'Bearer %s' % api_key,
9192
'X-User-Agent': f'python-sdk {SDK_VERSION}'
9293
}
94+
self._data_row_metadata_ontology = None
9395

9496
@retry.Retry(predicate=retry.if_exception_type(
9597
labelbox.exceptions.InternalServerError))
@@ -269,6 +271,8 @@ def get_error_status_code(error):
269271

270272
if get_error_status_code(internal_server_error) == 400:
271273
raise labelbox.exceptions.InvalidQueryError(message)
274+
elif _DATAROW_METADATA_CREATE_ERROR in message:
275+
raise labelbox.exceptions.ResourceCreationError(message)
272276
else:
273277
raise labelbox.exceptions.InternalServerError(message)
274278

@@ -648,7 +652,9 @@ def get_data_row_metadata_ontology(self) -> DataRowMetadataOntology:
648652
DataRowMetadataOntology: The ontology for Data Row Metadata for an organization
649653
650654
"""
651-
return DataRowMetadataOntology(self)
655+
if self._data_row_metadata_ontology is None:
656+
self._data_row_metadata_ontology = DataRowMetadataOntology(self)
657+
return self._data_row_metadata_ontology
652658

653659
def get_model(self, model_id) -> Model:
654660
""" Gets a single Model with the given ID.

labelbox/exceptions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ class InvalidQueryError(LabelboxError):
7171
pass
7272

7373

74+
class ResourceCreationError(LabelboxError):
75+
""" Indicates that a resource could not be created in the server side
76+
due to a validation or transaction error"""
77+
pass
78+
79+
7480
class NetworkError(LabelboxError):
7581
"""Raised when an HTTPError occurs."""
7682

@@ -122,4 +128,4 @@ class MALValidationError(LabelboxError):
122128

123129
class OperationNotAllowedException(Exception):
124130
"""Raised when user does not have permissions to a resource or has exceeded usage limit"""
125-
pass
131+
pass

labelbox/orm/model.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,9 @@ class Field:
3232
Attributes:
3333
field_type (Field.Type): The type of the field.
3434
name (str): name that the attribute has in client-side Python objects
35-
grapgql_name (str): name that the attribute has in queries (and in
35+
graphql_name (str): name that the attribute has in queries (and in
3636
server-side database definition).
37+
result_subquery (str): graphql query result payload for a field.
3738
"""
3839

3940
class Type(Enum):
@@ -54,6 +55,25 @@ def __init__(self, enum_cls: type):
5455
def name(self):
5556
return self.enum_cls.__name__
5657

58+
class ListType:
59+
""" Represents Field that is a list of some object.
60+
Args:
61+
list_cls (type): Type of object that list is made of.
62+
graphql_type (str): Inner object's graphql type.
63+
By default, the list_cls's name is used as the graphql type.
64+
"""
65+
66+
def __init__(self, list_cls: type, graphql_type=None):
67+
self.list_cls = list_cls
68+
if graphql_type is None:
69+
self.graphql_type = self.list_cls.__name__
70+
else:
71+
self.graphql_type = graphql_type
72+
73+
@property
74+
def name(self):
75+
return f"[{self.graphql_type}]"
76+
5777
class Order(Enum):
5878
""" Type of sort ordering. """
5979
Asc = auto()
@@ -91,10 +111,15 @@ def Enum(enum_cls: type, *args):
91111
def Json(*args):
92112
return Field(Field.Type.Json, *args)
93113

114+
@staticmethod
115+
def List(list_cls: type, graphql_type=None, **kwargs):
116+
return Field(Field.ListType(list_cls, graphql_type), **kwargs)
117+
94118
def __init__(self,
95-
field_type: Union[Type, EnumType],
119+
field_type: Union[Type, EnumType, ListType],
96120
name,
97-
graphql_name=None):
121+
graphql_name=None,
122+
result_subquery=None):
98123
""" Field init.
99124
Args:
100125
field_type (Field.Type): The type of the field.
@@ -103,12 +128,14 @@ def __init__(self,
103128
graphql_name (str): query and server-side name of a database object.
104129
If None, it is constructed from the client-side name by converting
105130
snake_case (Python convention) into camelCase (GraphQL convention).
131+
result_subquery (str): graphql query result payload for a field.
106132
"""
107133
self.field_type = field_type
108134
self.name = name
109135
if graphql_name is None:
110136
graphql_name = utils.camel_case(name)
111137
self.graphql_name = graphql_name
138+
self.result_subquery = result_subquery
112139

113140
@property
114141
def asc(self):

labelbox/orm/query.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,11 @@ def results_query_part(entity):
4141
entity (type): The entity which needs fetching.
4242
"""
4343
# Query for fields
44-
fields = [field.graphql_name for field in entity.fields()]
44+
fields = [
45+
field.result_subquery
46+
if field.result_subquery is not None else field.graphql_name
47+
for field in entity.fields()
48+
]
4549

4650
# Query for cached relationships
4751
fields.extend([

labelbox/schema/batch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,10 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
104104
response = requests.get(download_url)
105105
response.raise_for_status()
106106
reader = ndjson.reader(StringIO(response.text))
107-
return (
108-
Entity.DataRow(self.client, result) for result in reader)
107+
# TODO: Update result to parse customMetadata when resolver returns
108+
return (Entity.DataRow(self.client, {
109+
**result, 'customMetadata': []
110+
}) for result in reader)
109111
elif res["status"] == "FAILED":
110112
raise LabelboxError("Data row export failed.")
111113

labelbox/schema/data_row.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import logging
2-
from datetime import datetime
3-
from typing import List, Dict, Union, TYPE_CHECKING
2+
from typing import TYPE_CHECKING
43

54
from labelbox.orm import query
65
from labelbox.orm.db_object import DbObject, Updateable, BulkDeletable
76
from labelbox.orm.model import Entity, Field, Relationship
7+
from labelbox.schema.data_row_metadata import DataRowMetadataField # type: ignore
88

99
if TYPE_CHECKING:
1010
from labelbox import AssetAttachment
@@ -22,6 +22,7 @@ class DataRow(DbObject, Updateable, BulkDeletable):
2222
updated_at (datetime)
2323
created_at (datetime)
2424
media_attributes (dict): generated media attributes for the datarow
25+
custom_metadata (list): metadata associated with the datarow
2526
2627
dataset (Relationship): `ToOne` relationship to Dataset
2728
created_by (Relationship): `ToOne` relationship to User
@@ -34,6 +35,11 @@ class DataRow(DbObject, Updateable, BulkDeletable):
3435
updated_at = Field.DateTime("updated_at")
3536
created_at = Field.DateTime("created_at")
3637
media_attributes = Field.Json("media_attributes")
38+
custom_metadata = Field.List(
39+
DataRowMetadataField,
40+
graphql_type="DataRowCustomMetadataUpsertInput!",
41+
name="custom_metadata",
42+
result_subquery="customMetadata { value schemaId }")
3743

3844
# Relationships
3945
dataset = Relationship.ToOne("Dataset")
@@ -95,4 +101,4 @@ def create_attachment(self, attachment_type,
95101
data_row_id_param: self.uid
96102
})
97103
return Entity.AssetAttachment(self.client,
98-
res["createDataRowAttachment"])
104+
res["createDataRowAttachment"])

labelbox/schema/data_row_metadata.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def _parse_ontology(raw_ontology) -> List[DataRowMetadataSchema]:
200200

201201
return fields
202202

203+
def refresh_ontology(self):
204+
self._raw_ontology = self._get_ontology()
205+
self._build_ontology()
206+
203207
def parse_metadata(
204208
self, unparsed: List[Dict[str,
205209
List[Union[str,
@@ -221,6 +225,14 @@ def parse_metadata(
221225
for dr in unparsed:
222226
fields = []
223227
for f in dr["fields"]:
228+
if f["schemaId"] not in self.fields_by_id:
229+
# Update metadata ontology if field can't be found
230+
self.refresh_ontology()
231+
if f["schemaId"] not in self.fields_by_id:
232+
raise ValueError(
233+
f"Schema Id `{f['schemaId']}` not found in ontology"
234+
)
235+
224236
schema = self.fields_by_id[f["schemaId"]]
225237
if schema.kind == DataRowMetadataKind.enum:
226238
continue
@@ -295,9 +307,8 @@ def _batch_upsert(
295307
data_row_id=m.data_row_id,
296308
fields=list(
297309
chain.from_iterable(
298-
self._parse_upsert(m) for m in m.fields))).dict(
310+
self.parse_upsert(m) for m in m.fields))).dict(
299311
by_alias=True))
300-
301312
res = _batch_operations(_batch_upsert, items, self._batch_size)
302313
return res
303314

@@ -393,14 +404,17 @@ def _bulk_export(_data_row_ids: List[str]) -> List[DataRowMetadata]:
393404
data_row_ids,
394405
batch_size=self._batch_size)
395406

396-
def _parse_upsert(
407+
def parse_upsert(
397408
self, metadatum: DataRowMetadataField
398409
) -> List[_UpsertDataRowMetadataInput]:
399410
"""Format for metadata upserts to GQL"""
400411

401412
if metadatum.schema_id not in self.fields_by_id:
402-
raise ValueError(
403-
f"Schema Id `{metadatum.schema_id}` not found in ontology")
413+
# Update metadata ontology if field can't be found
414+
self.refresh_ontology()
415+
if metadatum.schema_id not in self.fields_by_id:
416+
raise ValueError(
417+
f"Schema Id `{metadatum.schema_id}` not found in ontology")
404418

405419
schema = self.fields_by_id[metadatum.schema_id]
406420

@@ -428,8 +442,11 @@ def _validate_delete(self, delete: DeleteDataRowMetadata):
428442
deletes = set()
429443
for schema_id in delete.fields:
430444
if schema_id not in self.fields_by_id:
431-
raise ValueError(
432-
f"Schema Id `{schema_id}` not found in ontology")
445+
# Update metadata ontology if field can't be found
446+
self.refresh_ontology()
447+
if schema_id not in self.fields_by_id:
448+
raise ValueError(
449+
f"Schema Id `{schema_id}` not found in ontology")
433450

434451
schema = self.fields_by_id[schema_id]
435452
# handle users specifying enums by adding all option enums

labelbox/schema/dataset.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import Iterable
66
import time
77
import ndjson
8-
from itertools import islice
8+
from itertools import islice, chain
99

1010
from concurrent.futures import ThreadPoolExecutor, as_completed
1111
from io import StringIO
@@ -78,6 +78,18 @@ def create_data_row(self, **kwargs) -> "DataRow":
7878
if os.path.exists(row_data):
7979
kwargs[DataRow.row_data.name] = self.client.upload_file(row_data)
8080
kwargs[DataRow.dataset.name] = self
81+
82+
# Parse metadata fields, if they are provided
83+
if DataRow.custom_metadata.name in kwargs:
84+
mdo = self.client.get_data_row_metadata_ontology()
85+
metadata_fields = kwargs[DataRow.custom_metadata.name]
86+
metadata = list(
87+
chain.from_iterable(
88+
mdo.parse_upsert(m) for m in metadata_fields))
89+
kwargs[DataRow.custom_metadata.name] = [
90+
md.dict(by_alias=True) for md in metadata
91+
]
92+
8193
return self.client._create(DataRow, kwargs)
8294

8395
def create_data_rows_sync(self, items) -> None:
@@ -397,8 +409,10 @@ def export_data_rows(self, timeout_seconds=120) -> Generator:
397409
response = requests.get(download_url)
398410
response.raise_for_status()
399411
reader = ndjson.reader(StringIO(response.text))
400-
return (
401-
Entity.DataRow(self.client, result) for result in reader)
412+
# TODO: Update result to parse customMetadata when resolver returns
413+
return (Entity.DataRow(self.client, {
414+
**result, 'customMetadata': []
415+
}) for result in reader)
402416
elif res["status"] == "FAILED":
403417
raise LabelboxError("Data row export failed.")
404418

tests/integration/test_client_errors.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,21 @@
22
import os
33
import time
44
import pytest
5+
from datetime import datetime
56

67
from labelbox import Project, Dataset, User
8+
from labelbox.schema.data_row_metadata import DataRowMetadataField
79
import labelbox.client
810
import labelbox.exceptions
911

12+
SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal"
13+
TEST_SPLIT_ID = "cko8scbz70005h2dkastwhgqt"
14+
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
15+
TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
16+
CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
17+
IMAGE_EMBEDDING_SCHEMA_ID = "ckrzang79000008l6hb5s6za1"
18+
TEXT_EMBEDDING_SCHEMA_ID = "ckrzao09x000108l67vrcdnh3"
19+
1020

1121
def test_missing_api_key():
1222
key = os.environ.get(labelbox.client._LABELBOX_API_KEY, None)
@@ -132,3 +142,28 @@ def get(arg):
132142

133143
# Sleep at the end of this test to allow other tests to execute.
134144
time.sleep(60)
145+
146+
147+
@pytest.mark.skip("Staging environment not returning correct exception")
148+
def test_resource_creation_error(dataset, image_url):
149+
150+
def make_metadata_fields():
151+
embeddings = [0.0] * 128
152+
msg = "A message"
153+
time = datetime.utcnow()
154+
155+
fields = [
156+
DataRowMetadataField(schema_id=SPLIT_SCHEMA_ID,
157+
value=TEST_SPLIT_ID),
158+
DataRowMetadataField(schema_id=CAPTURE_DT_SCHEMA_ID, value=time),
159+
DataRowMetadataField(schema_id=TEXT_SCHEMA_ID, value=msg),
160+
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
161+
value=embeddings),
162+
DataRowMetadataField(schema_id=EMBEDDING_SCHEMA_ID,
163+
value=embeddings),
164+
]
165+
return fields
166+
167+
with pytest.raises(labelbox.exceptions.ResourceCreationError) as excinfo:
168+
dataset.create_data_row(row_data=image_url,
169+
custom_metadata=make_metadata_fields())

tests/integration/test_data_row_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,4 @@ def test_parse_raw_metadata(mdo):
274274

275275
for row in parsed:
276276
for field in row.fields:
277-
assert mdo._parse_upsert(field)
277+
assert mdo.parse_upsert(field)

Comments (0): 0 commit comments