|
1 | | -from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation |
2 | | -from typing import Any, Dict, Optional |
3 | | -from pydantic import BaseModel |
| 1 | +from typing import Dict, Optional, Union |
| 2 | +from enum import Enum |
4 | 3 |
|
| 4 | +from pydantic import confloat |
5 | 5 |
|
6 | | -class ScalarMetric(BaseModel): |
7 | | - """ Class representing metrics |
| 6 | +from .base import ConfidenceValue, BaseMetric |
8 | 7 |
|
9 | | - # For backwards compatibility, metric_name is optional. This will eventually be deprecated |
10 | | - # The metric_name will be set to a default name in the editor if it is not set. |
| 8 | +ScalarMetricValue = confloat(ge=0, le=10_000) |
| 9 | +ScalarMetricConfidenceValue = Dict[ConfidenceValue, ScalarMetricValue] |
11 | 10 |
|
12 | | - # aggregation will be ignored wihtout providing a metric name. |
13 | | - # Not providing a metric name is deprecated. |
| 11 | + |
| 12 | +class ScalarMetricAggregation(Enum): |
| 13 | + ARITHMETIC_MEAN = "ARITHMETIC_MEAN" |
| 14 | + GEOMETRIC_MEAN = "GEOMETRIC_MEAN" |
| 15 | + HARMONIC_MEAN = "HARMONIC_MEAN" |
| 16 | + SUM = "SUM" |
| 17 | + |
| 18 | + |
| 19 | +class ScalarMetric(BaseMetric): |
| 20 | + """ Class representing scalar metrics |
| 21 | +
|
| 22 | + For backwards compatibility, metric_name is optional. |
| 23 | + The metric_name will be set to a default name in the editor if it is not set. |
| 24 | + This is not recommended and support for empty metric_name fields will be removed. |
| 25 | + aggregation will be ignored wihtout providing a metric name. |
14 | 26 | """ |
15 | | - value: float |
16 | 27 | metric_name: Optional[str] = None |
17 | | - feature_name: Optional[str] = None |
18 | | - subclass_name: Optional[str] = None |
19 | | - aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN |
20 | | - extra: Dict[str, Any] = {} |
| 28 | + value: Union[ScalarMetricValue, ScalarMetricConfidenceValue] |
| 29 | + aggregation: ScalarMetricAggregation = ScalarMetricAggregation.ARITHMETIC_MEAN |
21 | 30 |
|
22 | 31 | def dict(self, *args, **kwargs): |
23 | 32 | res = super().dict(*args, **kwargs) |
24 | | - if res['metric_name'] is None: |
| 33 | + if res.get('metric_name') is None: |
25 | 34 | res.pop('aggregation') |
26 | | - return {k: v for k, v in res.items() if v is not None} |
| 35 | + return res |
0 commit comments