diff --git a/app/controllers/place.py b/app/controllers/place.py index 12a6b6c80ac1aa3ea7d99ab070c7bd45b5f29e8a..b870f7cfbdcfec4534f3a50ceb48379e015296b9 100644 --- a/app/controllers/place.py +++ b/app/controllers/place.py @@ -2,6 +2,7 @@ from werkzeug.exceptions import BadRequest from app.controllers.base import RestController from app.forms.place import AddressForm, PlaceForm +from app.models.project import Project from app.models.place import Address, Place @@ -11,20 +12,25 @@ class PlaceController(RestController): constant_fields = ['sales_force_id', 'address_id'] def get_form(self, filter_data): - """Return the Place form. - Args: - filter_data (dict) - """ + """Return the Place form.""" return PlaceForm class AddressController(RestController): """The Address controller.""" Model = Address + filters = { + 'project_id': lambda d: Project.id == d['project_id']} + + def query(self, filter_data): + """Construct a query with additional joins based on filter_data.""" + q = super(AddressController, self).query(filter_data) + if 'project_id' in filter_data: + q = q\ + .join(Place, Place.address_id == Address.id)\ + .join(Project, Project.place_id == Place.id) + return q def get_form(self, filter_data): - """Return the Address form. - Args: - filter_data (dict) - """ + """Return the Address form.""" return AddressForm diff --git a/app/controllers/project.py b/app/controllers/project.py index 878753e2672827b271efed6969788e59e147e13e..cd2b798486848cd55b7cd55b716d341d8abfd973 100644 --- a/app/controllers/project.py +++ b/app/controllers/project.py @@ -27,8 +27,7 @@ class ProjectController(RestController): 'q': lambda d: and_(*[ Project.name.like('%{}%'.format(term)) for term in d['q'].split(' ') - ]) - } + ])} def get_form(self, filter_data): """Return the project form.""" diff --git a/app/tests/test_place.py b/app/tests/test_place.py index 5aba1bc40dc4fd065418063bc9cdce4a8c6444e7..a2c12d40fdfebde7691832ee621e5f6735388199 100644 --- a/app/tests/test_place.py +++ b/app/tests/test_place.py @@ -2,6 +2,7 @@ from sqlalchemy import and_ from app.lib.database import db from app.tests.base import RestTestCase +from app.models.project import Project from app.models.place import Place, Address @@ -103,6 +104,22 @@ class TestAddress(RestTestCase): model = self.env.address self._test_index() + def test_index_project_id(self): + """Tests /address/?project_id=... GET.""" + # Build a model that should be in the response. + project = self.env.project + model = db.session.query(Address)\ + .join(Place, Place.address_id == Address.id)\ + .join(Project, Project.place_id == Place.id)\ + .filter(Project.id == project.id)\ + .first() + # Build a model that should not be in the response. + _ = self.env.address + + response_data = self._test_index({'project_id': project.id}) + self.assertEquals(len(response_data), 1) + self.assertEquals(response_data[0]['id'], model.id) + def test_get(self): """Tests /address/ GET.""" model = self.env.address