diff --git a/app/controllers/client.py b/app/controllers/client.py index e9ab58337559fa03dad0eaba83d7ecc11a33f39d..e7a081b2e962b6190428f404c7ee4664b223bde6 100644 --- a/app/controllers/client.py +++ b/app/controllers/client.py @@ -1,6 +1,8 @@ """Controllers for managing clients.""" from werkzeug.exceptions import BadRequest + from app.controllers.base import RestController +from app.models.project import Project from app.models.client import Client from app.forms.client import ClientForm @@ -9,6 +11,15 @@ class ClientController(RestController): """The client controller.""" Model = Client constant_fields = ['sales_force_id'] + filters = { + 'project_id': lambda d: Project.id == d['project_id']} + + def query(self, filter_data): + """Construct a query with additional joins based on filter_data.""" + q = super(ClientController, self).query(filter_data) + if 'project_id' in filter_data: + q = q.join(Project, Project.client_id == Client.id) + return q def get_form(self, filter_data): """Return the client form.""" diff --git a/app/tests/test_client.py b/app/tests/test_client.py index 0dca40928934543b353fbe29bf0f06f96325fca0..8cd245154b56f1450a59887c411da07eb3cf52ca 100644 --- a/app/tests/test_client.py +++ b/app/tests/test_client.py @@ -15,6 +15,20 @@ class TestClient(RestTestCase): model = self.env.client self._test_index() + def test_index_project_id(self): + """Tests /client/?project_id=... GET.""" + # Construct a model that should be returned. + project = self.env.project + model = db.session.query(Client)\ + .filter(Client.id == project.client_id)\ + .first() + # Construct a model that should not be returned. + _ = self.env.client + + response_data = self._test_index({'project_id': project.id}) + self.assertEqual(len(response_data), 1) + self.assertEqual(response_data[0]['id'], model.id) + def test_get(self): """Tests /client/ GET.""" model = self.env.client