|
|
- """Tests for elpy.rpc."""
-
- import json
- import unittest
- import sys
-
- from elpy import rpc
- from elpy.tests.compat import StringIO
-
-
class TestFault(unittest.TestCase):
    """Checks the message, code and data carried by ``rpc.Fault``."""

    def test_should_have_code_and_data(self):
        """Explicitly passed ``code`` and ``data`` are readable afterwards."""
        message = "Hello"
        error = rpc.Fault(message, code=250, data="Fnord")

        self.assertEqual(str(error), message)
        self.assertEqual(error.code, 250)
        self.assertEqual(error.data, "Fnord")

    def test_should_have_defaults_for_code_and_data(self):
        """Without keyword arguments, ``code`` is 500 and ``data`` is None."""
        message = "Hello"
        error = rpc.Fault(message)

        self.assertEqual(str(error), message)
        self.assertEqual(error.code, 500)
        self.assertIsNone(error.data)
-
-
class TestJSONRPCServer(unittest.TestCase):
    """Base fixture: an ``rpc.JSONRPCServer`` wired to in-memory streams.

    Subclasses feed requests with :meth:`write` and collect whatever the
    server emitted with :meth:`read`.
    """

    def setUp(self):
        """Create fresh input/output buffers and a server that uses them."""
        self.stdin = StringIO()
        self.stdout = StringIO()
        self.rpc = rpc.JSONRPCServer(self.stdin, self.stdout)

    def _reset_streams(self):
        """Rewind and empty the input buffer, then the output buffer."""
        for stream in (self.stdin, self.stdout):
            stream.seek(0)
            stream.truncate()

    def write(self, s):
        """Replace the pending input with ``s`` and discard old output.

        The input buffer is rewound afterwards so the next read from it
        starts at the beginning of ``s``.
        """
        self._reset_streams()
        self.stdin.write(s)
        self.stdin.seek(0)

    def read(self):
        """Return everything written to the output, then clear both buffers."""
        captured = self.stdout.getvalue()
        self._reset_streams()
        return captured
-
-
class TestInit(TestJSONRPCServer):
    """Checks which streams ``rpc.JSONRPCServer`` ends up holding."""

    def test_should_use_arguments(self):
        """Streams passed to the constructor are stored on the server."""
        server = self.rpc
        self.assertEqual(server.stdin, self.stdin)
        self.assertEqual(server.stdout, self.stdout)

    def test_should_default_to_sys(self):
        """A server built without arguments uses ``sys.stdin``/``sys.stdout``."""
        server = rpc.JSONRPCServer()
        self.assertEqual(sys.stdin, server.stdin)
        self.assertEqual(sys.stdout, server.stdout)
-
-
class TestReadJson(TestJSONRPCServer):
    """Checks ``read_json`` against line-delimited JSON input."""

    def test_should_read_json(self):
        """Each call returns the next JSON document, one per input line."""
        expected = [{'foo': 'bar'},
                    {'baz': 'qux', 'fnord': 'argl\nbargl'},
                    "beep\r\nbeep\r\nbeep"]
        lines = [json.dumps(item) + "\n" for item in expected]
        self.write("".join(lines))

        for item in expected:
            self.assertEqual(self.rpc.read_json(), item)

    def test_should_raise_eof_on_eof(self):
        """Reading from an empty input raises ``EOFError``."""
        with self.assertRaises(EOFError):
            self.rpc.read_json()

    def test_should_fail_on_malformed_json(self):
        """A line that is not valid JSON raises ``ValueError``."""
        self.write("malformed json\n")
        with self.assertRaises(ValueError):
            self.rpc.read_json()
-
-
class TestWriteJson(TestJSONRPCServer):
    """Checks that ``write_json`` emits its keyword arguments as JSON."""

    def test_should_write_json_line(self):
        """Whatever is written decodes back to the keyword arguments given."""
        payloads = ({'foo': 'bar'},
                    {'baz': 'qux', 'fnord': 'argl\nbargl'})
        for payload in payloads:
            self.rpc.write_json(**payload)
            written = self.read()
            self.assertEqual(json.loads(written), payload)
-
-
class TestHandleRequest(TestJSONRPCServer):
    """Drives ``handle_request`` with single requests on the fake input.

    Each test queues one JSON request, optionally installs an ``rpc_<name>``
    attribute on the server, and inspects what was written to the output.
    """

    def test_should_fail_if_json_does_not_contain_a_method(self):
        """A request without a ``method`` key raises ``ValueError``."""
        self.write(json.dumps(dict(params=[],
                                   id=23)))
        self.assertRaises(ValueError,
                          self.rpc.handle_request)

    def test_should_call_right_method(self):
        """``method='foo'`` is dispatched to ``rpc_foo`` with the params."""
        self.write(json.dumps(dict(method='foo',
                                   params=[1, 2, 3],
                                   id=23)))
        self.rpc.rpc_foo = lambda *params: params
        self.rpc.handle_request()
        self.assertEqual(json.loads(self.read()),
                         dict(id=23,
                              result=[1, 2, 3]))

    def test_should_pass_defaults_for_missing_parameters(self):
        """A request with only ``method`` calls it with no arguments.

        The request carries no ``id``; the test also asserts that nothing
        is written to the output in that case.
        """
        def test_method(*params):
            self.args = params

        # Start from a known value so that a handler which is never called
        # shows up as a failed assertion rather than an AttributeError.
        self.args = None

        self.write(json.dumps(dict(method='foo')))
        self.rpc.rpc_foo = test_method
        self.rpc.handle_request()
        self.assertEqual(self.args, ())
        self.assertEqual(self.read(), "")

    def test_should_return_error_for_missing_method(self):
        """An unknown method yields an error response with the request id."""
        self.write(json.dumps(dict(method='foo',
                                   id=23)))
        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"],
                         "Unknown method foo")

    def test_should_return_error_for_exception_in_method(self):
        """An exception in the method is reported with message and traceback."""
        def test_method():
            raise ValueError("An error was raised")

        self.write(json.dumps(dict(method='foo',
                                   id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"], "An error was raised")
        self.assertIn("traceback", result["error"]["data"])

    def test_should_not_include_traceback_for_faults(self):
        """A raised ``rpc.Fault`` is reported without any traceback."""
        def test_method():
            raise rpc.Fault("This is a fault")

        self.write(json.dumps(dict(method="foo",
                                   id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["id"], 23)
        self.assertEqual(result["error"]["message"], "This is a fault")
        self.assertNotIn("traceback", result["error"])
        # Ordinary exceptions carry their traceback under error["data"]
        # (see test_should_return_error_for_exception_in_method), so that
        # level has to be checked too or this test could never fail.
        # "data" may be absent or null for a fault without data.
        self.assertNotIn("traceback", result["error"].get("data") or {})

    def test_should_add_data_for_faults(self):
        """The ``data`` given to a ``rpc.Fault`` is passed through as-is."""
        def test_method():
            raise rpc.Fault("St. Andreas' Fault",
                            code=12345, data="Yippieh")

        self.write(json.dumps(dict(method="foo", id=23)))
        self.rpc.rpc_foo = test_method

        self.rpc.handle_request()
        result = json.loads(self.read())

        self.assertEqual(result["error"]["data"], "Yippieh")

    def test_should_call_handle_for_unknown_method(self):
        """With a ``handle`` attribute set, unknown methods are routed to it."""
        def test_handle(method_name, args):
            return "It works"

        self.write(json.dumps(dict(method="doesnotexist",
                                   id=23)))
        self.rpc.handle = test_handle
        self.rpc.handle_request()
        self.assertEqual(json.loads(self.read()),
                         dict(id=23,
                              result="It works"))
-
-
class TestServeForever(TestJSONRPCServer):
    """Checks how ``serve_forever`` loops and which exceptions end it.

    The server's ``handle_request`` is replaced by a counting stub that
    raises ``self.error`` once it has been called more than ten times.
    """

    def setUp(self):
        """Install the counting stub on a freshly built server."""
        super(TestServeForever, self).setUp()
        self.hr_called = 0
        self.error = KeyboardInterrupt
        self.rpc.handle_request = self.handle_request

    def handle_request(self):
        """Count the call; from the eleventh call on, raise ``self.error``."""
        self.hr_called += 1
        if self.hr_called <= 10:
            return
        raise self.error()

    def test_should_call_handle_request_repeatedly(self):
        """The loop keeps calling the stub until it raises on call eleven."""
        self.rpc.serve_forever()
        self.assertEqual(self.hr_called, 11)

    def test_should_return_on_some_errors(self):
        """These three exception types make ``serve_forever`` return."""
        # The call counter is not reset between runs, so after the first
        # run the stub raises on its very first call.
        for stopper in (KeyboardInterrupt, EOFError, SystemExit):
            self.error = stopper
            self.rpc.serve_forever()

    def test_should_fail_on_most_errors(self):
        """A ``RuntimeError`` from the stub propagates out of the loop."""
        self.error = RuntimeError
        with self.assertRaises(RuntimeError):
            self.rpc.serve_forever()
|