1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37 import os
38 import gc
39 import logging
40 import sys
41 import contextlib
42
43
44 try:
45 import unittest2 as unittest
46 except ImportError:
47 import unittest
48
49 import pike.model as model
50 import pike.smb2 as smb2
54 init_done = False
55
@staticmethod
def option(name, default=None):
    """Read a test configuration option from the process environment.

    :param name: name of the environment variable to consult.
    :param default: value returned when the variable is unset or is set
        to the empty string.
    :returns: the variable's (non-empty) string value, otherwise ``default``.
    """
    raw = os.environ.get(name, '')
    # An empty string counts as "not configured", same as an unset variable.
    if not raw:
        return default
    return raw
66
67 @staticmethod
69 table = {'yes': True, 'no': False, '': False}
70 return table[PikeTest.option(name, 'no')]
71
72 @staticmethod
75
76 @staticmethod
89
91 unittest.TestCase.__init__(self, *args, **kwargs)
92 self.init_once()
93 self.server = self.option('PIKE_SERVER')
94 self.port = int(self.option('PIKE_PORT', '445'))
95 self.creds = self.option('PIKE_CREDS')
96 self.share = self.option('PIKE_SHARE', 'c$')
97 self.signing = self.booloption('PIKE_SIGN')
98 self.encryption = self.booloption('PIKE_ENCRYPT')
99 self.min_dialect = self.smb2constoption('PIKE_MIN_DIALECT')
100 self.max_dialect = self.smb2constoption('PIKE_MAX_DIALECT')
101 self._connections = []
102 self.default_client = model.Client()
103 if self.min_dialect is not None:
104 self.default_client.dialects = filter(
105 lambda d: d >= self.min_dialect,
106 self.default_client.dialects)
107 if self.max_dialect is not None:
108 self.default_client.dialects = filter(
109 lambda d: d <= self.max_dialect,
110 self.default_client.dialects)
111 if self.signing:
112 self.default_client.security_mode = (smb2.SMB2_NEGOTIATE_SIGNING_ENABLED |
113 smb2.SMB2_NEGOTIATE_SIGNING_REQUIRED)
114
def debug(self, *args, **kwargs):
    """Log a DEBUG-level message through the instance's ``logger``.

    All positional and keyword arguments are forwarded unchanged to
    ``self.logger.debug``.
    """
    emit = self.logger.debug
    emit(*args, **kwargs)
117
def info(self, *args, **kwargs):
    """Log an INFO-level message through the instance's ``logger``.

    All positional and keyword arguments are forwarded unchanged to
    ``self.logger.info``.
    """
    emit = self.logger.info
    emit(*args, **kwargs)
120
def warn(self, *args, **kwargs):
    """Log a WARNING-level message through the instance's ``logger``.

    All positional and keyword arguments are forwarded unchanged.  The
    documented ``Logger.warning`` is called instead of the deprecated
    ``Logger.warn`` alias, which emits a DeprecationWarning on Python 3.
    The public name ``warn`` is kept so existing callers keep working.
    """
    self.logger.warning(*args, **kwargs)
123
def error(self, *args, **kwargs):
    """Log an ERROR-level message through the instance's ``logger``.

    All positional and keyword arguments are forwarded unchanged to
    ``self.logger.error``.
    """
    emit = self.logger.error
    emit(*args, **kwargs)
126
128 self.logger.critical(*args, **kwargs)
129
131 dialect_range = self.required_dialect()
132 req_caps = self.required_capabilities()
133 req_share_caps = self.required_share_capabilities()
134
135 if client is None:
136 client = self.default_client
137
138 conn = client.connect(self.server, self.port).negotiate()
139
140 if (conn.negotiate_response.dialect_revision < dialect_range[0] or
141 conn.negotiate_response.dialect_revision > dialect_range[1]):
142 self.skipTest("Dialect required: %s" % str(dialect_range))
143
144 if conn.negotiate_response.capabilities & req_caps != req_caps:
145 self.skipTest("Capabilities missing: %s " %
146 str(req_caps & ~conn.negotiate_response.capabilities))
147
148 chan = conn.session_setup(self.creds, resume=resume)
149 if self.encryption:
150 chan.session.encrypt_data = True
151
152 tree = chan.tree_connect(self.share)
153
154 if tree.tree_connect_response.capabilities & req_share_caps != req_share_caps:
155 self.skipTest("Share capabilities missing: %s" %
156 str(req_share_caps & ~tree.tree_connect_response.capabilities))
157 self._connections.append(conn)
158 return (chan,tree)
159
160 - class _AssertErrorContext(object):
162
163 @contextlib.contextmanager
165 e = None
166 o = PikeTest._AssertErrorContext()
167
168 try:
169 yield o
170 except model.ResponseError as e:
171 pass
172
173 if e is None:
174 raise self.failureException('No error raised when "%s" expected' % status)
175 elif e.response.status != status:
176 raise self.failureException('"%s" raised when "%s" expected' % (e.response.status, status))
177
178 o.response = e.response
179
181 if self.loglevel != logging.NOTSET:
182 print >>sys.stderr
183
184 if hasattr(self, 'setup'):
185 self.setup()
186
188 if hasattr(self, 'teardown'):
189 self.teardown()
190
191 for conn in self._connections:
192 conn.close()
193 del self._connections[:]
194
195 gc.collect()
196
198 name = '__pike_test_' + name
199 test_method = getattr(self, self._testMethodName)
200
201 if hasattr(test_method, name):
202 return getattr(test_method, name)
203 elif hasattr(self.__class__, name):
204 return getattr(self.__class__, name)
205 else:
206 return default
207
210
213
216
218 """
219 Compare two sequences using a binary diff to efficiently determine
220 the first offset where they differ
221 """
222 if len(buf1) != len(buf2):
223 raise AssertionError("Buffers are not the same size")
224 low = 0
225 high = len(buf1)
226 while high - low > 1:
227 chunk_1 = (low, low+(high-low)/2)
228 chunk_2 = (low+(high-low)/2, high)
229 if buf1[chunk_1[0]:chunk_1[1]] != buf2[chunk_1[0]:chunk_1[1]]:
230 low, high = chunk_1
231 elif buf1[chunk_2[0]:chunk_2[1]] != buf2[chunk_2[0]:chunk_2[1]]:
232 low, high = chunk_2
233 else:
234 break
235 if high - low <= 1:
236 raise AssertionError("Block mismatch at byte {0}: "
237 "{1} != {2}".format(low, buf1[low], buf2[low]))
238
242
244 setattr(thing, '__pike_test_' + self.__class__.__name__, self.value)
245 return thing
246
def __init__(self, minvalue=0, maxvalue=float('inf')):
    """Store a lower and an upper bound on the instance.

    :param minvalue: lower bound, defaults to 0.
    :param maxvalue: upper bound, defaults to positive infinity (no limit).
    """
    # NOTE(review): presumably consumed as a (min, max) pair by the
    # decorator's __call__ — confirm against the enclosing class.
    self.minvalue, self.maxvalue = minvalue, maxvalue
252
254 setattr(thing, '__pike_test_' + self.__class__.__name__,
255 (self.minvalue, self.maxvalue))
256 return thing
257
262
265 """
266 Custom test suite for easily patching in skip tests in downstream
267 distributions of these test cases
268 """
269 skip_tests_reasons = {
270 "test_to_be_skipped": "This test should be skipped",
271 }
272
273 @staticmethod
275 def inner(*args, **kwds):
276 raise unittest.SkipTest(reason)
277 return inner
278
288
291 test_loader = unittest.TestLoader()
292 test_loader.suiteClass = PikeTestSuite
293 test_suite = test_loader.discover(
294 os.path.abspath(os.path.dirname(__file__)),
295 "*.py")
296 return test_suite
297
if __name__ == '__main__':
    # Run every discovered test module and report the outcome through the
    # process exit status (0 = all passed, 1 = failures/errors), as
    # unittest.main() does, so shell scripts and CI can detect failures.
    test_result = unittest.TextTestRunner().run(suite())
    sys.exit(0 if test_result.wasSuccessful() else 1)
301