1+ from pydantic import ValidationError
12import pytest
23
3- from labelbox .data .annotation_types .metrics . aggregations import MetricAggregation
4- from labelbox .data .annotation_types .metrics . scalar import ScalarMetric
4+ from labelbox .data .annotation_types .metrics import ConfusionMatrixAggregation , ScalarMetricAggregation
5+ from labelbox .data .annotation_types .metrics import ConfusionMatrixMetric , ScalarMetric
56from labelbox .data .annotation_types .collection import LabelList
67from labelbox .data .annotation_types import ScalarMetric , Label , ImageData
78
@@ -30,23 +31,28 @@ def test_legacy_scalar_metric():
3031 'uid' : None
3132 }
3233 assert label .dict () == expected
33- next (LabelList ([label ])).dict () == expected
34+ assert next (LabelList ([label ])).dict () == expected
3435
3536
3637# TODO: Test with confidence
3738
38- @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
39- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
40- ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
41- (None , None , MetricAggregation .ARITHMETIC_MEAN ),
42- (None , None , None ),
43- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
44- ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
45- (None , None , MetricAggregation .GEOMETRIC_MEAN ),
46- (None , None , MetricAggregation .SUM )
39+
40+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation,value' , [
41+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
42+ ("cat" , None , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
43+ (None , None , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
44+ (None , None , None , 0.5 ),
45+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , 0.5 ),
46+ ("cat" , None , ScalarMetricAggregation .HARMONIC_MEAN , 0.5 ),
47+ (None , None , ScalarMetricAggregation .GEOMETRIC_MEAN , 0.5 ),
48+ (None , None , ScalarMetricAggregation .SUM , 0.5 ),
49+ ("cat" , "orange" , ScalarMetricAggregation .ARITHMETIC_MEAN , {
50+ 0.1 : 0.2 ,
51+ 0.3 : 0.5 ,
52+ 0.4 : 0.8
53+ }),
4754])
48- def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
49- value = 0.5
55+ def test_custom_scalar_metric (feature_name , subclass_name , aggregation , value ):
5056 kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
5157 metric = ScalarMetric (metric_name = "iou" ,
5258 value = value ,
@@ -77,36 +83,37 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
7783 ** ({
7884 'subclass_name' : subclass_name
7985 } if subclass_name else {}), 'aggregation' :
80- aggregation or MetricAggregation .ARITHMETIC_MEAN ,
86+ aggregation or ScalarMetricAggregation .ARITHMETIC_MEAN ,
8187 'extra' : {}
8288 }],
8389 'extra' : {},
8490 'uid' : None
8591 }
86- assert label .dict () == expected
87- next (LabelList ([label ])).dict () == expected
88-
8992
93+ assert label .dict () == expected
94+ assert next (LabelList ([label ])).dict () == expected
9095
9196
92- @pytest .mark .parametrize ('feature_name,subclass_name,aggregation' , [
93- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
94- ("cat" , None , MetricAggregation .ARITHMETIC_MEAN ),
95- (None , None , MetricAggregation .ARITHMETIC_MEAN ),
96- (None , None , None ),
97- ("cat" , "orange" , MetricAggregation .ARITHMETIC_MEAN ),
98- ("cat" , None , MetricAggregation .HARMONIC_MEAN ),
99- (None , None , MetricAggregation .GEOMETRIC_MEAN ),
100- (None , None , MetricAggregation .SUM ),
97+ @pytest .mark .parametrize ('feature_name,subclass_name,aggregation,value' , [
98+ ("cat" , "orange" , ConfusionMatrixAggregation .CONFUSION_MATRIX ,
99+ (0 , 1 , 2 , 3 )),
100+ ("cat" , None , ConfusionMatrixAggregation .CONFUSION_MATRIX , (0 , 1 , 2 , 3 )),
101+ (None , None , ConfusionMatrixAggregation .CONFUSION_MATRIX , (0 , 1 , 2 , 3 )),
102+ (None , None , None , (0 , 1 , 2 , 3 )),
103+ ("cat" , "orange" , ConfusionMatrixAggregation .CONFUSION_MATRIX , {
104+ 0.1 : (0 , 1 , 2 , 3 ),
105+ 0.3 : (0 , 1 , 2 , 3 ),
106+ 0.4 : (0 , 1 , 2 , 3 )
107+ }),
101108])
102- def test_custom_scalar_metric (feature_name , subclass_name , aggregation ):
103- value = 0.5
109+ def test_custom_confusison_matrix_metric (feature_name , subclass_name ,
110+ aggregation , value ):
104111 kwargs = {'aggregation' : aggregation } if aggregation is not None else {}
105- metric = ScalarMetric (metric_name = "iou " ,
106- value = value ,
107- feature_name = feature_name ,
108- subclass_name = subclass_name ,
109- ** kwargs )
112+ metric = ConfusionMatrixMetric (metric_name = "confusion_matrix_50_pct_iou " ,
113+ value = value ,
114+ feature_name = feature_name ,
115+ subclass_name = subclass_name ,
116+ ** kwargs )
110117 assert metric .value == value
111118
112119 label = Label (data = ImageData (uid = "ckrmd9q8g000009mg6vej7hzg" ),
@@ -124,18 +131,58 @@ def test_custom_scalar_metric(feature_name, subclass_name, aggregation):
124131 'value' :
125132 value ,
126133 'metric_name' :
127- 'iou ' ,
134+ 'confusion_matrix_50_pct_iou ' ,
128135 ** ({
129136 'feature_name' : feature_name
130137 } if feature_name else {}),
131138 ** ({
132139 'subclass_name' : subclass_name
133140 } if subclass_name else {}), 'aggregation' :
134- aggregation or MetricAggregation . ARITHMETIC_MEAN ,
141+ aggregation or ConfusionMatrixAggregation . CONFUSION_MATRIX ,
135142 'extra' : {}
136143 }],
137144 'extra' : {},
138145 'uid' : None
139146 }
140147 assert label .dict () == expected
141- next (LabelList ([label ])).dict () == expected
148+ assert next (LabelList ([label ])).dict () == expected
149+
150+
def test_name_exists():
    """ConfusionMatrixMetric must reject construction without `metric_name`.

    Name is only required for ConfusionMatrixMetric for now; ScalarMetric
    still allows it to be omitted (see the legacy scalar-metric test).
    """
    with pytest.raises(ValidationError) as exc_info:
        # No `metric_name` supplied — pydantic should flag the missing field.
        ConfusionMatrixMetric(value=[0, 1, 2, 3])
    assert "field required (type=value_error.missing)" in str(exc_info.value)
156+
157+
def test_invalid_aggregations():
    """Aggregation enums are metric-type specific and must not be mixed.

    A ScalarMetric only accepts ScalarMetricAggregation members, and a
    ConfusionMatrixMetric only accepts ConfusionMatrixAggregation members;
    crossing them must raise a pydantic ValidationError.
    """
    # Confusion-matrix aggregation passed to a scalar metric — rejected.
    with pytest.raises(ValidationError) as exc_info:
        ScalarMetric(metric_name="invalid aggregation",
                     value=0.1,
                     aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX)
    assert "value is not a valid enumeration member" in str(exc_info.value)

    # Scalar aggregation passed to a confusion-matrix metric — rejected.
    with pytest.raises(ValidationError) as exc_info:
        ConfusionMatrixMetric(metric_name="invalid aggregation",
                              value=[0, 1, 2, 3],
                              aggregation=ScalarMetricAggregation.SUM)
    assert "value is not a valid enumeration member" in str(exc_info.value)
170+
171+
def test_invalid_number_of_confidence_scores():
    """Confidence-keyed metric values must contain an allowed number of scores.

    Both ScalarMetric and ConfusionMatrixMetric accept a dict mapping
    confidence thresholds to values; a single entry is too few and twenty
    entries are too many, and each case must raise a ValidationError whose
    message mentions the confidence-score bound.
    """
    # Too few scores: a single confidence threshold is below the minimum.
    with pytest.raises(ValidationError) as exc_info:
        ScalarMetric(metric_name="too few scores", value={0.1: 0.1})
    assert "Number of confidence scores must be greater" in str(exc_info.value)

    with pytest.raises(ValidationError) as exc_info:
        ConfusionMatrixMetric(metric_name="too few scores",
                              value={0.1: [0, 1, 2, 3]})
    assert "Number of confidence scores must be greater" in str(exc_info.value)

    # Too many scores: 20 thresholds exceeds the maximum.
    with pytest.raises(ValidationError) as exc_info:
        ScalarMetric(metric_name="too many scores",
                     value={i / 20.: 0.1 for i in range(20)})
    assert "Number of confidence scores must be greater" in str(exc_info.value)

    with pytest.raises(ValidationError) as exc_info:
        ConfusionMatrixMetric(
            metric_name="too many scores",
            value={i / 20.: [0, 1, 2, 3] for i in range(20)})
    assert "Number of confidence scores must be greater" in str(exc_info.value)
0 commit comments