Improve from_dict performance by caching fields
This commit is contained in:
parent
48d9caef05
commit
1c5abb8419
1 changed files with 20 additions and 9 deletions
|
@ -12,7 +12,8 @@ import dataclasses
|
||||||
|
|
||||||
from dataclasses import dataclass, is_dataclass
|
from dataclasses import dataclass, is_dataclass
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import Dict, List, Optional, Type, TypeVar, Union
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
from typing import get_type_hints
|
from typing import get_type_hints
|
||||||
|
|
||||||
from toot.typing_compat import get_args, get_origin
|
from toot.typing_compat import get_args, get_origin
|
||||||
|
@ -435,17 +436,27 @@ def from_dict(cls: Type[T], data: Dict) -> T:
|
||||||
data = prepare(data)
|
data = prepare(data)
|
||||||
|
|
||||||
def _fields():
|
def _fields():
|
||||||
hints = get_type_hints(cls)
|
for name, type, default in get_fields(cls):
|
||||||
for field in dataclasses.fields(cls):
|
value = data.get(name, default)
|
||||||
field_type = _prune_optional(hints[field.name])
|
converted = _convert_with_error_handling(cls, name, type, value)
|
||||||
default_value = _get_default_value(field)
|
yield name, converted
|
||||||
value = data.get(field.name, default_value)
|
|
||||||
converted = _convert_with_error_handling(cls, field.name, field_type, value)
|
|
||||||
yield field.name, converted
|
|
||||||
|
|
||||||
return cls(**dict(_fields()))
|
return cls(**dict(_fields()))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=100)
|
||||||
|
def get_fields(cls: Type) -> List[Tuple[str, Type, Any]]:
|
||||||
|
hints = get_type_hints(cls)
|
||||||
|
return [
|
||||||
|
(
|
||||||
|
field.name,
|
||||||
|
_prune_optional(hints[field.name]),
|
||||||
|
_get_default_value(field)
|
||||||
|
)
|
||||||
|
for field in dataclasses.fields(cls)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]:
|
def from_dict_list(cls: Type[T], data: List[Dict]) -> List[T]:
|
||||||
return [from_dict(cls, x) for x in data]
|
return [from_dict(cls, x) for x in data]
|
||||||
|
|
||||||
|
@ -497,7 +508,7 @@ def _convert(field_type, value):
|
||||||
raise ValueError(f"Not implemented for type '{field_type}'")
|
raise ValueError(f"Not implemented for type '{field_type}'")
|
||||||
|
|
||||||
|
|
||||||
def _prune_optional(field_type):
|
def _prune_optional(field_type: Type) -> Type:
|
||||||
"""For `Optional[<type>]` returns the encapsulated `<type>`."""
|
"""For `Optional[<type>]` returns the encapsulated `<type>`."""
|
||||||
if get_origin(field_type) == Union:
|
if get_origin(field_type) == Union:
|
||||||
args = get_args(field_type)
|
args = get_args(field_type)
|
||||||
|
|
Loading…
Reference in a new issue