|
1 | 1 | """Support for django rest framework symmetric serialization"""
|
2 | 2 |
|
3 |
| -__all__ = ['EnumField'] |
| 3 | +__all__ = ['EnumField', 'EnumFieldMixin'] |
4 | 4 |
|
5 | 5 | try:
|
6 | 6 | from enum import Enum
|
7 | 7 | from typing import Any, Type, Union
|
8 | 8 |
|
9 | 9 | from django_enum.utils import choices, determine_primitive
|
| 10 | + from django_enum import EnumField as EnumModelField |
10 | 11 | from rest_framework.fields import ChoiceField
|
| 12 | + from rest_framework.serializers import ClassLookupDict |
| 13 | + from rest_framework.utils.field_mapping import get_field_kwargs |
| 14 | + |
11 | 15 |
|
12 | 16 | class EnumField(ChoiceField):
|
13 | 17 | """
|
@@ -37,7 +41,7 @@ def __init__(
|
37 | 41 | self.enum = enum
|
38 | 42 | self.primitive = determine_primitive(enum) # type: ignore
|
39 | 43 | assert self.primitive is not None, \
|
40 |
| - f'Unable to determine primitive type for {Enum}' |
| 44 | + f'Unable to determine primitive type for {enum}' |
41 | 45 | self.strict = strict
|
42 | 46 | self.choices = kwargs.pop('choices', choices(enum))
|
43 | 47 | super().__init__(choices=self.choices, **kwargs)
|
@@ -72,6 +76,54 @@ def to_representation( # pylint: disable=R0201
|
72 | 76 | return getattr(value, 'value', value)
|
73 | 77 |
|
74 | 78 |
|
| 79 | + class EnumFieldMixin: |
| 80 | + """ |
| 81 | + A mixin for ModelSerializers that adds auto-magic support for |
| 82 | + EnumFields. |
| 83 | + """ |
| 84 | + |
| 85 | + def build_standard_field(self, field_name, model_field): |
| 86 | + """ |
| 87 | + The default implementation of build_standard_field will set any |
| 88 | + field with choices to a ChoiceField. This will override that for |
| 89 | + EnumFields and add enum and strict arguments to the field's kwargs. |
| 90 | +
|
| 91 | + To use this mixin, include it before ModelSerializer in your |
| 92 | + serializer's class hierarchy: |
| 93 | +
|
| 94 | + ..code-block:: python |
| 95 | +
|
| 96 | + from django_enum.drf import EnumFieldMixin |
| 97 | + from rest_framework.serializers import ModelSerializer |
| 98 | +
|
| 99 | + class MySerializer(EnumFieldMixin, ModelSerializer): |
| 100 | +
|
| 101 | + class Meta: |
| 102 | + model = MyModel |
| 103 | + fields = '__all__' |
| 104 | +
|
| 105 | +
|
| 106 | + :param field_name: The name of the field on the serializer |
| 107 | + :param model_field: The Field instance on the model |
| 108 | + :return: A 2-tuple, the first element is the field class, the |
| 109 | + second is the kwargs for the field |
| 110 | + """ |
| 111 | + try: |
| 112 | + field_class = ClassLookupDict({EnumModelField: EnumField})[ |
| 113 | + model_field |
| 114 | + ] |
| 115 | + return field_class, { |
| 116 | + 'enum': model_field.enum, |
| 117 | + 'strict': model_field.strict, |
| 118 | + **super().build_standard_field( |
| 119 | + field_name, |
| 120 | + model_field |
| 121 | + )[1], |
| 122 | + } |
| 123 | + except KeyError: |
| 124 | + return super().build_standard_field(field_name, model_field) |
| 125 | + |
| 126 | + |
75 | 127 | except (ImportError, ModuleNotFoundError):
|
76 | 128 |
|
77 | 129 | class _MissingDjangoRestFramework:
|
|
0 commit comments