 from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation
-from typing import Any, Dict, Optional
-from pydantic import BaseModel
+from typing import Any, Dict, Optional, Tuple, Union
+from pydantic import BaseModel, Field, validator, confloat
 
 
-class ScalarMetric(BaseModel):
-    """ Class representing metrics
+ScalarMetricConfidenceValue = Dict[confloat(ge=0, le=1), float]
+ConfusionMatrixMetricConfidenceValue = Dict[confloat(ge=0, le=1), Tuple[int, int, int, int]]
 
-    # For backwards compatibility, metric_name is optional. This will eventually be deprecated
-    # The metric_name will be set to a default name in the editor if it is not set.
 
-    # aggregation will be ignored wihtout providing a metric name.
-    # Not providing a metric name is deprecated.
-    """
-    value: float
+class BaseMetric(BaseModel):
     metric_name: Optional[str] = None
     feature_name: Optional[str] = None
     subclass_name: Optional[str] = None
-    aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
     extra: Dict[str, Any] = {}
 
+
+class ScalarMetric(BaseMetric):
+    """ Class representing scalar metrics
+
+    For backwards compatibility, metric_name is optional.
+    The metric_name will be set to a default name in the editor if it is not set.
+    This is not recommended and support for empty metric_name fields will be removed.
+    aggregation will be ignored unless a metric name is provided.
+    """
+    value: Union[float, ScalarMetricConfidenceValue]
+    aggregation: MetricAggregation = MetricAggregation.ARITHMETIC_MEAN
+
     def dict(self, *args, **kwargs):
         res = super().dict(*args, **kwargs)
         if res['metric_name'] is None:
             res.pop('aggregation')
         return {k: v for k, v in res.items() if v is not None}
+
+    @validator('aggregation')
+    def validate_aggregation(cls, aggregation):
+        if aggregation == MetricAggregation.CONFUSION_MATRIX:
+            raise ValueError("Cannot assign `MetricAggregation.CONFUSION_MATRIX` to `ScalarMetric.aggregation`")
+        return aggregation
+
+
+class ConfusionMatrixMetric(BaseMetric):
+    """ Class representing confusion matrix metrics.
+
+    In the editor, this provides precision, recall, and f-scores.
+    Use this instead of multiple scalar metrics so that aggregations are accurate.
+
+    value should be a tuple representing:
+    [True Positive Count, False Positive Count, True Negative Count, False Negative Count]
+
+    aggregation cannot be adjusted for confusion matrix metrics.
+    """
+    value: Union[Tuple[int, int, int, int], ConfusionMatrixMetricConfidenceValue]
+    aggregation: MetricAggregation = Field(MetricAggregation.CONFUSION_MATRIX, const=True)
+
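
A minimal usage sketch of the new classes, assuming pydantic v1 semantics for `validator` and `confloat` as imported above; the metric names and counts are illustrative, not taken from the source:

    from labelbox.data.annotation_types.metrics.aggregations import MetricAggregation

    # A plain scalar metric: one named float value for a feature.
    iou = ScalarMetric(metric_name="iou", feature_name="dog", value=0.54)

    # A confidence-keyed scalar metric: maps confidence thresholds in
    # [0, 1] to values, matching ScalarMetricConfidenceValue above.
    precision = ScalarMetric(metric_name="precision",
                             value={0.25: 0.8, 0.5: 0.7, 0.75: 0.4})

    # A confusion matrix metric: value is (TP, FP, TN, FN) counts.
    # A confidence-keyed dict like {0.5: (10, 2, 0, 3)} is also accepted.
    cm = ConfusionMatrixMetric(metric_name="confusion_matrix",
                               value=(10, 2, 0, 3))

    # The validator rejects the confusion matrix aggregation on scalars;
    # pydantic v1 surfaces this as a ValidationError (a ValueError subclass).
    try:
        ScalarMetric(metric_name="bad", value=0.5,
                     aggregation=MetricAggregation.CONFUSION_MATRIX)
    except ValueError:
        pass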
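The overridden `dict()` drives serialization: when `metric_name` is missing, `aggregation` is popped and all `None` fields are stripped. A sketch of the expected output, inferred from the field definitions above rather than confirmed against editor behavior:

    named = ScalarMetric(metric_name="iou", value=0.5)
    # metric_name is set, so aggregation survives serialization.
    print(named.dict())
    # {'metric_name': 'iou', 'extra': {}, 'value': 0.5,
    #  'aggregation': <MetricAggregation.ARITHMETIC_MEAN: ...>}

    # Omitting metric_name is deprecated; aggregation is dropped and
    # the None-valued name fields are stripped.
    unnamed = ScalarMetric(value=0.5)
    print(unnamed.dict())
    # {'extra': {}, 'value': 0.5}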