pygo/gomill/gomill_tests/gomill_test_support.py

162 lines
5.3 KiB
Python

"""Gomill-specific test support code."""
import re
from gomill import __version__
from gomill_tests.test_framework import unittest2
from gomill_tests import test_framework
from gomill.common import *
from gomill import ascii_boards
from gomill import boards
# This makes TestResult ignore lines from this module in tracebacks
__unittest = True
def compare_boards(b1, b2):
"""Check whether two boards have the same position.
returns a pair (position_is_the_same, message)
"""
if b1.side != b2.side:
raise ValueError("size is different: %s, %s" % (b1.side, b2.side))
differences = []
for row, col in b1.board_points:
if b1.get(row, col) != b2.get(row, col):
differences.append((row, col))
if not differences:
return True, None
msg = "boards differ at %s" % " ".join(map(format_vertex, differences))
try:
msg += "\n%s\n%s" % (
ascii_boards.render_board(b1), ascii_boards.render_board(b2))
except Exception:
pass
return False, msg
def compare_diagrams(d1, d2):
"""Compare two ascii board diagrams.
returns a pair (strings_are_equal, message)
(assertMultiLineEqual tends to look nasty for these, so we just show them
both in full)
"""
if d1 == d2:
return True, None
return False, "diagrams differ:\n%s\n\n%s" % (d1, d2)
def scrub_sgf(s):
"""Normalise sgf string for convenience of testing.
Replaces dates with '***', and 'gomill:<__version__>' with 'gomill:VER'.
Be careful: gomill version length can affect line wrapping. Either
serialise with wrap=None or remove newlines before comparing.
"""
s = re.sub(r"(?m)(?<=^Date ).*$", "***", s)
s = re.sub(r"(?<=DT\[)[-0-9]+(?=\])", "***", s)
s = re.sub(r"gomill:" + re.escape(__version__), "gomill:VER", s)
return s
traceback_line_re = re.compile(
r" .*/([a-z0-9_]+)\.pyc?:[0-9]+ \(([a-z0-9_]+)\)")
class Gomill_testcase_mixin(object):
"""TestCase mixin adding support for gomill-specific types.
This adds:
assertBoardEqual
assertEqual and assertNotEqual for Boards
"""
def init_gomill_testcase_mixin(self):
self.addTypeEqualityFunc(boards.Board, self.assertBoardEqual)
def _format_message(self, msg, standardMsg):
# This is the same as _formatMessage from python 2.7 unittest; copying
# it because it's not part of the public API.
if not self.longMessage:
return msg or standardMsg
if msg is None:
return standardMsg
try:
return '%s : %s' % (standardMsg, msg)
except UnicodeDecodeError:
return '%s : %s' % (unittest2.util.safe_repr(standardMsg),
unittest2.util.safe_repr(msg))
def assertBoardEqual(self, b1, b2, msg=None):
are_equal, desc = compare_boards(b1, b2)
if not are_equal:
self.fail(self._format_message(msg, desc+"\n"))
def assertDiagramEqual(self, d1, d2, msg=None):
are_equal, desc = compare_diagrams(d1, d2)
if not are_equal:
self.fail(self._format_message(msg, desc+"\n"))
def assertNotEqual(self, first, second, msg=None):
if isinstance(first, boards.Board) and isinstance(second, boards.Board):
are_equal, _ = compare_boards(first, second)
if not are_equal:
return
msg = self._format_message(msg, 'boards have the same position')
raise self.failureException(msg)
super(Gomill_testcase_mixin, self).assertNotEqual(first, second, msg)
def assertTracebackStringEqual(self, seen, expected, fixups=()):
"""Compare two strings which include tracebacks.
This is for comparing strings containing tracebacks from
the compact_tracebacks module.
Replaces the traceback lines describing source locations with
'<filename>|<functionname>', for robustness.
fixups -- list of pairs of strings
(additional substitutions to make in the 'seen' string)
"""
lines = seen.split("\n")
new_lines = []
for l in lines:
match = traceback_line_re.match(l)
if match:
l = "|".join(match.groups())
for a, b in fixups:
l = l.replace(a, b)
new_lines.append(l)
self.assertMultiLineEqual("\n".join(new_lines), expected)
class Gomill_SimpleTestCase(
Gomill_testcase_mixin, test_framework.SimpleTestCase):
"""SimpleTestCase with the Gomill mixin."""
def __init__(self, *args, **kwargs):
test_framework.SimpleTestCase.__init__(self, *args, **kwargs)
self.init_gomill_testcase_mixin()
class Gomill_ParameterisedTestCase(
Gomill_testcase_mixin, test_framework.ParameterisedTestCase):
"""ParameterisedTestCase with the Gomill mixin."""
def __init__(self, *args, **kwargs):
test_framework.ParameterisedTestCase.__init__(self, *args, **kwargs)
self.init_gomill_testcase_mixin()
def make_simple_tests(source, prefix="test_"):
"""Make test cases from a module's test_xxx functions.
See test_framework for details.
The test functions can use the Gomill_testcase_mixin enhancements.
"""
return test_framework.make_simple_tests(
source, prefix, testcase_class=Gomill_SimpleTestCase)