Skip to content
Snippets Groups Projects
Commit d57f67d3 authored by florimondmanca's avatar florimondmanca
Browse files

update visit participant API

parent 8b02fd40
No related branches found
No related tags found
No related merge requests found
...@@ -19,16 +19,18 @@ urlpatterns = [ ...@@ -19,16 +19,18 @@ urlpatterns = [
url(r'^auth/get-token/$', obtain_auth_token, name='get-auth-token'), url(r'^auth/get-token/$', obtain_auth_token, name='get-auth-token'),
] ]
router = routers.DefaultRouter() router = routers.SimpleRouter()
# Visits views # Visits views
router.register('visits', visits_views.VisitViewSet) router.register('visits', visits_views.VisitViewSet)
router.register('visit-participants', visits_views.VisitParticipantViewSet) router.register('visit-participants', visits_views.VisitParticipantsViewSet,
base_name='visit-participants')
# Users views # Users views
router.register(r'users', users_views.UserViewSet) router.register(r'users', users_views.UserViewSet)
router.register(r'tutors', users_views.TutorViewSet) router.register(r'tutors', users_views.TutorViewSet)
router.register(r'students', users_views.StudentViewSet) router.register(r'students', users_views.StudentViewSet)
# router.register('student-visits', users_views.StudentVisitsViewSet)
router.register(r'schoolstaffmembers', users_views.SchoolStaffMemberViewSet) router.register(r'schoolstaffmembers', users_views.SchoolStaffMemberViewSet)
# Tutoring views # Tutoring views
......
...@@ -37,15 +37,3 @@ class VisitEndpointsTest(HyperlinkedAPITestCase): ...@@ -37,15 +37,3 @@ class VisitEndpointsTest(HyperlinkedAPITestCase):
url = '/api/visits/{obj.pk}/participants/'.format(obj=obj) url = '/api/visits/{obj.pk}/participants/'.format(obj=obj)
response = self.client.get(url) response = self.client.get(url)
return response return response
def test_participants_contains_participants_of_visit(self):
self.client.force_login(UserFactory.create())
obj = self.factory.create()
response = self.perform_list_participants(obj=obj)
num_participants = obj.participants.all().count()
self.assertEqual(len(response.data), num_participants)
def test_participants_authentication_required(self):
self.assertRequiresAuth(
self.perform_list_participants,
expected_status_code=status.HTTP_200_OK)
...@@ -3,9 +3,10 @@ ...@@ -3,9 +3,10 @@
from django.test import TestCase from django.test import TestCase
from rest_framework import status from rest_framework import status
from tests.factory import VisitParticipantFactory from tests.factory import VisitParticipantFactory
from tests.factory import VisitFactory, StudentFactory from tests.factory import VisitFactory, StudentFactory, UserFactory
from tests.utils import HyperlinkedAPITestCase from tests.utils import HyperlinkedAPITestCase
from visits.serializers import VisitParticipantWriteSerializer from visits.serializers import VisitParticipantWriteSerializer
from visits.models import VisitParticipant
class VisitParticipantEndpointsTest(HyperlinkedAPITestCase): class VisitParticipantEndpointsTest(HyperlinkedAPITestCase):
...@@ -22,18 +23,19 @@ class VisitParticipantEndpointsTest(HyperlinkedAPITestCase): ...@@ -22,18 +23,19 @@ class VisitParticipantEndpointsTest(HyperlinkedAPITestCase):
VisitFactory.create_batch(10) VisitFactory.create_batch(10)
cls.factory.create_batch(5) cls.factory.create_batch(5)
def perform_list(self): # def perform_list(self):
url = '/api/visit-participants/' # url = '/api/visit-participants/'
response = self.client.get(url) # response = self.client.get(url)
return response # return response
#
def test_list_authentication_required(self): # def test_list_authentication_required(self):
self.assertRequiresAuth( # self.assertRequiresAuth(
self.perform_list, expected_status_code=status.HTTP_200_OK) # self.perform_list, expected_status_code=status.HTTP_200_OK)
def perform_retrieve(self): def perform_retrieve(self, obj=None):
if obj is None:
obj = self.factory.create() obj = self.factory.create()
url = '/api/visit-participants/{obj.pk}/'.format(obj=obj) url = '/api/visit-participants/{obj.visit.pk}/'.format(obj=obj)
response = self.client.get(url) response = self.client.get(url)
return response return response
...@@ -41,9 +43,16 @@ class VisitParticipantEndpointsTest(HyperlinkedAPITestCase): ...@@ -41,9 +43,16 @@ class VisitParticipantEndpointsTest(HyperlinkedAPITestCase):
self.assertRequiresAuth( self.assertRequiresAuth(
self.perform_retrieve, expected_status_code=status.HTTP_200_OK) self.perform_retrieve, expected_status_code=status.HTTP_200_OK)
def test_retrieve_returns_participants_of_visit(self):
obj = self.factory.create()
response = self.perform_retrieve(obj=obj)
self.assertEqual(
len(response.json()),
VisitParticipant.objects.filter(visit=obj.visit).count())
def perform_create(self): def perform_create(self):
url = '/api/visit-participants/'
obj = self.factory.build() obj = self.factory.build()
url = '/api/visit-participants/'
data = self.serialize(obj, 'post', url) data = self.serialize(obj, 'post', url)
response = self.client.post(url, data, format='json') response = self.client.post(url, data, format='json')
return response return response
......
...@@ -50,7 +50,7 @@ class VisitParticipantTest(ModelTestCase): ...@@ -50,7 +50,7 @@ class VisitParticipantTest(ModelTestCase):
def test_get_absolute_url(self): def test_get_absolute_url(self):
self.client.force_login(UserFactory.create()) self.client.force_login(UserFactory.create())
url = self.obj.get_absolute_url() url = self.obj.get_absolute_url()
expected = '/api/visit-participants/{}/'.format(self.obj.pk) expected = '/api/visit-participants/{}/'.format(self.obj.visit.pk)
self.assertEqual(url, expected) self.assertEqual(url, expected)
response = self.client.get(url) response = self.client.get(url)
self.assertEqual(200, response.status_code) self.assertEqual(200, response.status_code)
...@@ -54,7 +54,7 @@ class VisitParticipant(models.Model): ...@@ -54,7 +54,7 @@ class VisitParticipant(models.Model):
unique_together = (('student', 'visit'),) unique_together = (('student', 'visit'),)
def get_absolute_url(self): def get_absolute_url(self):
return reverse('api:visitparticipant-detail', args=[str(self.pk)]) return reverse('api:visit-participants-detail', args=[str(self.pk)])
# Permissions # Permissions
......
...@@ -45,11 +45,11 @@ class VisitParticipantReadSerializer(serializers.HyperlinkedModelSerializer): ...@@ -45,11 +45,11 @@ class VisitParticipantReadSerializer(serializers.HyperlinkedModelSerializer):
model = VisitParticipant model = VisitParticipant
fields = ('id', 'url', 'student', 'visit', 'present') fields = ('id', 'url', 'student', 'visit', 'present')
extra_kwargs = { extra_kwargs = {
'url': {'view_name': 'api:visitparticipant-detail'} 'url': {'view_name': 'api:visit-participants-detail'},
} }
class VisitParticipantWriteSerializer(serializers.HyperlinkedModelSerializer): class VisitParticipantWriteSerializer(serializers.ModelSerializer):
"""Writable serializer for visit participants.""" """Writable serializer for visit participants."""
student_id = serializers.PrimaryKeyRelatedField( student_id = serializers.PrimaryKeyRelatedField(
...@@ -86,3 +86,13 @@ class VisitParticipantDetailSerializer(serializers.ModelSerializer): ...@@ -86,3 +86,13 @@ class VisitParticipantDetailSerializer(serializers.ModelSerializer):
model = VisitParticipant model = VisitParticipant
fields = ('student_id', 'first_name', 'last_name', fields = ('student_id', 'first_name', 'last_name',
'phone_number', 'email', 'present',) 'phone_number', 'email', 'present',)
class VisitParticipantIdentifySerializer(serializers.ModelSerializer):
student_id = serializers.IntegerField()
visit_id = serializers.IntegerField()
class Meta: # noqa
model = VisitParticipant
fields = ('student_id', 'visit_id',)
"""Visits API views.""" """Visits API views."""
from django.shortcuts import get_object_or_404
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework import status from rest_framework import status
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.decorators import detail_route from rest_framework import mixins
from rest_framework.decorators import list_route
from dry_rest_permissions.generics import DRYPermissions from dry_rest_permissions.generics import DRYPermissions
from .serializers import VisitSerializer from .serializers import VisitSerializer
from .serializers import VisitParticipantReadSerializer from .serializers import VisitParticipantReadSerializer
from .serializers import VisitParticipantWriteSerializer from .serializers import VisitParticipantWriteSerializer
from .serializers import VisitParticipantIdentifySerializer
from .serializers import VisitParticipantDetailSerializer from .serializers import VisitParticipantDetailSerializer
from .models import Visit, VisitParticipant from .models import Visit, VisitParticipant
from users.models import Student
class VisitViewSet(viewsets.ReadOnlyModelViewSet): class VisitViewSet(viewsets.ReadOnlyModelViewSet):
...@@ -19,23 +23,57 @@ class VisitViewSet(viewsets.ReadOnlyModelViewSet): ...@@ -19,23 +23,57 @@ class VisitViewSet(viewsets.ReadOnlyModelViewSet):
queryset = Visit.objects.all() queryset = Visit.objects.all()
permission_classes = (DRYPermissions,) permission_classes = (DRYPermissions,)
@detail_route()
def participants(self, request, pk=None):
"""List participants of a visit with their contact information."""
visit = self.get_object()
participants = VisitParticipant.objects.filter(visit=visit)
serializer = VisitParticipantDetailSerializer(participants, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
class VisitParticipantsViewSet(mixins.CreateModelMixin,
class VisitParticipantViewSet(viewsets.ModelViewSet): mixins.ListModelMixin,
mixins.DestroyModelMixin,
viewsets.GenericViewSet):
"""API endpoints to manage participants of visits.""" """API endpoints to manage participants of visits."""
queryset = VisitParticipant.objects.all()
permission_classes = (DRYPermissions,) permission_classes = (DRYPermissions,)
def get_queryset(self):
if self.action in ['retrieve']:
return Visit.objects.all()
return VisitParticipant.objects.all()
def get_serializer_class(self): def get_serializer_class(self):
if self.action in ('list', 'retrieve'): if self.action == 'list':
return VisitParticipantReadSerializer return VisitParticipantReadSerializer
elif self.action == 'retrieve':
return VisitParticipantDetailSerializer
elif self.action == 'get_id':
return VisitParticipantIdentifySerializer
else: else:
return VisitParticipantWriteSerializer return VisitParticipantWriteSerializer
def retrieve(self, request, pk=None):
"""Retrieve the participants to a visit."""
visit = self.get_object()
participants = VisitParticipant.objects.filter(visit=visit)
serializer = self.get_serializer(participants, many=True,
context={'request': request})
return Response(serializer.data, status=status.HTTP_200_OK)
@list_route(methods=['put'])
def get_id(self, request):
"""Special endpoint to get ID of participant from student and visit.
Useful to perform a DELETE request afterwards (which only accepts
a participant ID).
"""
serializer = self.get_serializer(data=request.data)
if serializer.is_valid():
student = get_object_or_404(
Student,
pk=serializer.validated_data['student_id'])
visit = get_object_or_404(
Visit,
pk=serializer.validated_data['visit_id'])
participant = get_object_or_404(VisitParticipant,
student=student,
visit=visit)
return Response({'id': participant.id}, status=status.HTTP_200_OK)
else:
return Response(serializer.errors,
status=status.HTTP_400_BAD_REQUEST)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment