1- from labelbox .data .annotation_types .metrics import ScalarMetricAggregation
2- from typing import Union , Optional
1+ from typing import Optional , Union , Type
32
43from labelbox .data .annotation_types .data import ImageData , TextData
5- from labelbox .data .annotation_types .metrics import ScalarMetric
64from labelbox .data .serialization .ndjson .base import NDJsonBase
5+ from labelbox .data .annotation_types .metrics .scalar import (
6+ ScalarMetric , ScalarMetricAggregation , ScalarMetricValue ,
7+ ScalarMetricConfidenceValue )
8+ from labelbox .data .annotation_types .metrics .confusion_matrix import (
9+ ConfusionMatrixAggregation , ConfusionMatrixMetric ,
10+ ConfusionMatrixMetricValue , ConfusionMatrixMetricConfidenceValue )
711
812
9- class NDScalarMetric (NDJsonBase ):
13+ class BaseNDMetric (NDJsonBase ):
1014 metric_value : float
11- metric_name : Optional [str ]
1215 feature_name : Optional [str ] = None
1316 subclass_name : Optional [str ] = None
14- aggregation : ScalarMetricAggregation = ScalarMetricAggregation .ARITHMETIC_MEAN .value
17+
18+ class Config :
19+ use_enum_values = True
20+
21+ def dict (self , * args , ** kwargs ):
22+ res = super ().dict (* args , ** kwargs )
23+ for field in ['featureName' , 'subclassName' ]:
24+ if res [field ] is None :
25+ res .pop (field )
26+ return res
27+
28+
29+ class NDConfusionMatrixMetric (BaseNDMetric ):
30+ metric_value : Union [ConfusionMatrixMetricValue ,
31+ ConfusionMatrixMetricConfidenceValue ]
32+ metric_name : str
33+ aggregation : ConfusionMatrixAggregation
34+
35+ def to_common (self ) -> ConfusionMatrixMetric :
36+ return ConfusionMatrixMetric (value = self .metric_value ,
37+ metric_name = self .metric_name ,
38+ feature_name = self .feature_name ,
39+ subclass_name = self .subclass_name ,
40+ aggregation = self .aggregation ,
41+ extra = {'uuid' : self .uuid })
42+
43+ @classmethod
44+ def from_common (
45+ cls , metric : ConfusionMatrixMetric ,
46+ data : Union [TextData , ImageData ]) -> "NDConfusionMatrixMetric" :
47+ return cls (uuid = metric .extra .get ('uuid' ),
48+ metric_value = metric .value ,
49+ metric_name = metric .metric_name ,
50+ feature_name = metric .feature_name ,
51+ subclass_name = metric .subclass_name ,
52+ aggregation = metric .aggregation ,
53+ data_row = {'id' : data .uid })
54+
55+
56+ class NDScalarMetric (BaseNDMetric ):
57+ metric_value : Union [ScalarMetricValue , ScalarMetricConfidenceValue ]
58+ metric_name : Optional [str ]
59+ aggregation : ScalarMetricAggregation = ScalarMetricAggregation .ARITHMETIC_MEAN
1560
1661 def to_common (self ) -> ScalarMetric :
17- return ScalarMetric (
18- value = self .metric_value ,
19- metric_name = self .metric_name ,
20- feature_name = self .feature_name ,
21- subclass_name = self .subclass_name ,
22- aggregation = ScalarMetricAggregation [self .aggregation ],
23- extra = {'uuid' : self .uuid })
62+ return ScalarMetric (value = self .metric_value ,
63+ metric_name = self .metric_name ,
64+ feature_name = self .feature_name ,
65+ subclass_name = self .subclass_name ,
66+ aggregation = self .aggregation ,
67+ extra = {'uuid' : self .uuid })
2468
2569 @classmethod
2670 def from_common (cls , metric : ScalarMetric ,
@@ -35,38 +79,39 @@ def from_common(cls, metric: ScalarMetric,
3579
3680 def dict (self , * args , ** kwargs ):
3781 res = super ().dict (* args , ** kwargs )
38- for field in ['featureName' , 'subclassName' ]:
39- if res [field ] is None :
40- res .pop (field )
41-
4282 # For backwards compatibility.
4383 if res ['metricName' ] is None :
4484 res .pop ('metricName' )
4585 res .pop ('aggregation' )
4686 return res
4787
48- class Config :
49- use_enum_values = True
50-
5188
5289class NDMetricAnnotation :
5390
5491 @classmethod
55- def to_common (cls , annotation : "NDScalarMetric" ) -> ScalarMetric :
92+ def to_common (
93+ cls , annotation : Union [NDScalarMetric , NDConfusionMatrixMetric ]
94+ ) -> Union [ScalarMetric , ConfusionMatrixMetric ]:
5695 return annotation .to_common ()
5796
5897 @classmethod
59- def from_common (cls , annotation : ScalarMetric ,
60- data : Union [TextData , ImageData ]) -> "NDScalarMetric" :
98+ def from_common (
99+ cls , annotation : Union [ScalarMetric ,
100+ ConfusionMatrixMetric ], data : Union [TextData ,
101+ ImageData ]
102+ ) -> Union [NDScalarMetric , NDConfusionMatrixMetric ]:
61103 obj = cls .lookup_object (annotation )
62104 return obj .from_common (annotation , data )
63105
64106 @staticmethod
65- def lookup_object (metric : ScalarMetric ) -> "NDScalarMetric" :
107+ def lookup_object (
108+ annotation : Union [ScalarMetric , ConfusionMatrixMetric ]
109+ ) -> Union [Type [NDScalarMetric ], Type [NDConfusionMatrixMetric ]]:
66110 result = {
67111 ScalarMetric : NDScalarMetric ,
68- }.get (type (metric ))
112+ ConfusionMatrixMetric : NDConfusionMatrixMetric ,
113+ }.get (type (annotation ))
69114 if result is None :
70115 raise TypeError (
71- f"Unable to convert object to MAL format. `{ type (metric )} `" )
116+ f"Unable to convert object to MAL format. `{ type (annotation )} `" )
72117 return result
0 commit comments