Improve from_dict performance by caching fields

This commit is contained in:
Ivan Habunek 2023-11-26 09:16:21 +01:00
parent 48d9caef05
commit 1c5abb8419
No known key found for this signature in database
GPG key ID: F5F0623FF5EBCB3D

View file

@ -12,7 +12,8 @@ import dataclasses
from dataclasses import dataclass, is_dataclass from dataclasses import dataclass, is_dataclass
from datetime import date, datetime from datetime import date, datetime
from typing import Dict, List, Optional, Type, TypeVar, Union from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
from typing import get_type_hints from typing import get_type_hints
from toot.typing_compat import get_args, get_origin from toot.typing_compat import get_args, get_origin
@ -435,17 +436,27 @@ def from_dict(cls: Type[T], data: Dict) -> T:
data = prepare(data) data = prepare(data)
def _fields(): def _fields():
hints = get_type_hints(cls) for name, type, default in get_fields(cls):
for field in dataclasses.fields(cls): value = data.get(name, default)
field_type = _prune_optional(hints[field.name]) converted = _convert_with_error_handling(cls, name, type, value)
default_value = _get_default_value(field) yield name, converted
value = data.get(field.name, default_value)
converted = _convert_with_error_handling(cls, field.name, field_type, value)
yield field.name, converted
return cls(**dict(_fields())) return cls(**dict(_fields()))
@lru_cache(maxsize=100)
def get_fields(cls: Type) -> List[Tuple[str, Type, Any]]:
hints = get_type_hints(cls)
return [
(
field.name,
_prune_optional(hints[field.name]),
_get_default_value(field)
)
for field in dataclasses.fields(cls)
]
def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]: def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]:
return [from_dict(cls, x) for x in data] return [from_dict(cls, x) for x in data]
@ -497,7 +508,7 @@ def _convert(field_type, value):
raise ValueError(f"Not implemented for type '{field_type}'") raise ValueError(f"Not implemented for type '{field_type}'")
def _prune_optional(field_type): def _prune_optional(field_type: Type) -> Type:
"""For `Optional[<type>]` returns the encapsulated `<type>`.""" """For `Optional[<type>]` returns the encapsulated `<type>`."""
if get_origin(field_type) == Union: if get_origin(field_type) == Union:
args = get_args(field_type) args = get_args(field_type)