228 lines
7.7 KiB
Python
228 lines
7.7 KiB
Python
|
"""Tests for cem_tuners.py"""
|
||
|
|
||
|
from __future__ import with_statement, division
|
||
|
|
||
|
import cPickle as pickle
|
||
|
from textwrap import dedent
|
||
|
|
||
|
from gomill import cem_tuners
|
||
|
from gomill.game_jobs import Game_job, Game_job_result
|
||
|
from gomill.gtp_games import Game_result
|
||
|
from gomill.cem_tuners import Parameter_config
|
||
|
from gomill.competitions import (
|
||
|
Player_config, CompetitionError, ControlFileError)
|
||
|
|
||
|
from gomill_tests import gomill_test_support
|
||
|
|
||
|
def make_tests(suite):
    """Add this module's simple tests to suite.

    The tests are built by gomill_test_support.make_simple_tests() from this
    module's global namespace and handed to suite.addTests().

    """
    module_namespace = globals()
    simple_tests = gomill_test_support.make_simple_tests(module_namespace)
    suite.addTests(simple_tests)
|
||
|
|
||
|
def simple_make_candidate(*args):
    """make_candidate function used by the test configuration.

    Returns a Player_config whose command string is "cand" followed by the
    str() of each argument, separated by single spaces.

    Raises ValueError("oops") if any argument equals -1, which lets tests
    provoke a failure from inside the function.

    """
    if -1 in args:
        # test_make_candidate compares this line with the reported traceback,
        # so its text must stay exactly as it is.
        raise ValueError("oops")
    rendered = [str(arg) for arg in args]
    command = "cand " + " ".join(rendered)
    return Player_config(command)
|
||
|
|
||
|
def clip_axisb(f):
    """Transform for the 'axisb' parameter: a float clamped to [0.0, 100.0].

    The argument is converted with float(), so anything float() rejects
    (for example None) propagates as the exception float() raises.

    """
    # test_parameter_config compares this line with the reported traceback,
    # so its text must stay exactly as it is.
    f = float(f)
    capped = min(100.0, f)
    return max(0.0, capped)
|
||
|
|
||
|
def default_config():
    """Return a freshly built settings dict for a small test Cem_tuner.

    A new dict (with a new 'parameters' list) is built on every call, so
    callers can modify the result without affecting other tests.

    The configuration has one opponent ('opp') and two parameters: 'axisa',
    which has no transform, and 'axisb', which is transformed by clip_axisb.

    """
    opponent = Player_config("test")
    axisa = Parameter_config(
        'axisa',
        initial_mean=0.5,
        initial_variance=1.0,
        format="axa %.3f")
    axisb = Parameter_config(
        'axisb',
        initial_mean=50.0,
        initial_variance=1000.0,
        transform=clip_axisb,
        format="axb %.1f")

    config = {}
    # Game settings
    config['board_size'] = 13
    config['komi'] = 7.5
    config['players'] = {'opp': opponent}
    config['candidate_colour'] = 'w'
    config['opponent'] = 'opp'
    # Tuner settings
    config['parameters'] = [axisa, axisb]
    config['batch_size'] = 3
    config['samples_per_generation'] = 4
    config['number_of_generations'] = 3
    config['elite_proportion'] = 0.1
    config['step_size'] = 0.8
    config['make_candidate'] = simple_make_candidate
    return config
|
||
|
|
||
|
def test_parameter_config(tc):
    # Exercises parameter handling on a Cem_tuner built from default_config():
    # formatting of the initial distribution and of parameter vectors, the
    # axisb transform (including the error report when it fails), and
    # rejection of invalid Parameter_configs.
    tuner = cem_tuners.Cem_tuner('cemtest')
    tuner.initialise_from_control_file(default_config())

    tc.assertEqual(
        tuner.initial_distribution.format(),
        " 0.50~1.00 50.00~1000.00")
    tc.assertEqual(
        tuner.format_engine_parameters((0.5, 23)),
        "axa 0.500; axb 23.0")
    # A value which can't be rendered with the parameter's format string is
    # expected to show up in "[name?value]" form.
    tc.assertEqual(
        tuner.format_engine_parameters(('x', 23)),
        "[axisa?x]; axb 23.0")
    # The expected text has 500 clipped to 100.0, as clip_axisb would do.
    tc.assertEqual(
        tuner.format_optimiser_parameters((0.5, 500)),
        "axa 0.500; axb 100.0")

    tc.assertEqual(tuner.transform_parameters((0.5, 1e6)), (0.5, 100.0))

    # clip_axisb calls float(None), which raises TypeError.
    with tc.assertRaises(CompetitionError) as ar:
        tuner.transform_parameters((0.5, None))
    expected_report = dedent("""\
    error from transform for axisb
    TypeError: expected-float
    traceback (most recent call last):
    cem_tuner_tests|clip_axisb
    failing line:
    f = float(f)
    """)
    # Two alternative wordings of the TypeError message are both mapped to
    # the 'expected-float' token used in expected_report.
    float_message_fixups = [
        ("float() argument must be a string or a number", "expected-float"),
        ("expected float, got NoneType object", "expected-float"),
        ]
    tc.assertTracebackStringEqual(
        str(ar.exception), expected_report, fixups=float_message_fixups)

    # The Parameter_configs are built outside the 'with' blocks, so only
    # errors from parameter_spec_from_config() itself can satisfy the checks.
    negative_variance = Parameter_config(
        'pa1', initial_mean=0, initial_variance=-1, format="%s")
    with tc.assertRaisesRegexp(
            ValueError, "'initial_variance': must be nonnegative"):
        tuner.parameter_spec_from_config(negative_variance)

    format_without_percent = Parameter_config(
        'pa1', initial_mean=0, initial_variance=1, format="nopct")
    with tc.assertRaisesRegexp(
            ControlFileError, "'format': invalid format string"):
        tuner.parameter_spec_from_config(format_without_percent)
|
||
|
|
||
|
def test_nonsense_parameter_config(tc):
    # A 'parameters' entry which isn't a Parameter_config (here the int 99,
    # appended at index 2) must be rejected with a ControlFileError naming
    # the offending item.
    tuner = cem_tuners.Cem_tuner('cemtest')
    settings = default_config()
    settings['parameters'].append(99)
    with tc.assertRaises(ControlFileError) as ar:
        tuner.initialise_from_control_file(settings)
    expected_message = dedent("""\
    'parameters': item 2: not a Parameter""")
    tc.assertMultiLineEqual(str(ar.exception), expected_message)
|
||
|
|
||
|
def test_transform_check(tc):
    # Replaces the 'axisa' parameter with one whose transform is str.split,
    # which can't be applied to the float initial_mean; initialising the
    # tuner is expected to report that as a ControlFileError.
    tuner = cem_tuners.Cem_tuner('cemtest')
    settings = default_config()
    broken_axisa = Parameter_config(
        'axisa',
        initial_mean=0.5,
        initial_variance=1.0,
        transform=str.split)
    settings['parameters'][0] = broken_axisa
    with tc.assertRaises(ControlFileError) as ar:
        tuner.initialise_from_control_file(settings)
    expected_report = dedent("""\
    parameter axisa: error from transform (applied to initial_mean)
    TypeError: split-wants-float-not-str
    traceback (most recent call last):
    """)
    # Two alternative wordings of the TypeError message are both mapped to
    # the token used in expected_report.
    split_message_fixups = [
        ("descriptor 'split' requires a 'str' object but received a 'float'",
         "split-wants-float-not-str"),
        ("unbound method split() must be called with str instance as "
         "first argument (got float instance instead)",
         "split-wants-float-not-str"),
        ]
    tc.assertTracebackStringEqual(
        str(ar.exception), expected_report, fixups=split_message_fixups)
|
||
|
|
||
|
def test_format_validation(tc):
    # Replaces the 'axisa' parameter with one whose transform is str but
    # whose format string is "axa %f"; initialising the tuner is expected to
    # reject the format string with a ControlFileError.
    tuner = cem_tuners.Cem_tuner('cemtest')
    settings = default_config()
    mismatched_axisa = Parameter_config(
        'axisa',
        initial_mean=0.5,
        initial_variance=1.0,
        transform=str,
        format="axa %f")
    settings['parameters'][0] = mismatched_axisa
    with tc.assertRaisesRegexp(
            ControlFileError, "'format': invalid format string"):
        tuner.initialise_from_control_file(settings)
|
||
|
|
||
|
def test_make_candidate(tc):
    # Checks the player returned by Cem_tuner.make_candidate(), and the
    # CompetitionError reported when the configured make_candidate function
    # (simple_make_candidate) raises.
    tuner = cem_tuners.Cem_tuner('cemtest')
    tuner.initialise_from_control_file(default_config())

    candidate = tuner.make_candidate('g0#1', (0.5, 23.0))
    tc.assertEqual(candidate.code, 'g0#1')
    tc.assertListEqual(candidate.cmd_args, ['cand', '0.5', '23.0'])

    # simple_make_candidate raises ValueError when given -1.
    with tc.assertRaises(CompetitionError) as ar:
        tuner.make_candidate('g0#1', (-1, 23))
    expected_report = dedent("""\
    error from make_candidate()
    ValueError: oops
    traceback (most recent call last):
    cem_tuner_tests|simple_make_candidate
    failing line:
    raise ValueError("oops")
    """)
    tc.assertTracebackStringEqual(str(ar.exception), expected_report)
|
||
|
|
||
|
def test_play(tc):
    # Walks a Cem_tuner through its first two games: checks the jobs it hands
    # out, feeds back a win and a jigo, and checks the 'wins' tally after
    # each result and after a pickle round-trip of the tuner's status.
    tuner = cem_tuners.Cem_tuner('cemtest')
    tuner.initialise_from_control_file(default_config())
    tuner.set_clean_status()

    tc.assertEqual(tuner.generation, 0)
    tc.assertEqual(
        tuner.distribution.format(),
        " 0.50~1.00 50.00~1000.00")

    # First job: checked in full.
    first_job = tuner.get_game()
    tc.assertIsInstance(first_job, Game_job)
    tc.assertEqual(first_job.game_id, 'g0#0r0')
    tc.assertEqual(first_job.player_b.code, 'g0#0')
    tc.assertEqual(first_job.player_w.code, 'opp')
    tc.assertEqual(first_job.board_size, 13)
    tc.assertEqual(first_job.komi, 7.5)
    tc.assertEqual(first_job.move_limit, 1000)
    tc.assertIs(first_job.use_internal_scorer, False)
    tc.assertEqual(first_job.internal_scorer_handicap_compensation, 'full')
    tc.assertEqual(first_job.game_data, (0, 'g0#0', 0))
    tc.assertEqual(first_job.sgf_event, 'cemtest')
    tc.assertRegexpMatches(
        first_job.sgf_note, '^Candidate parameters: axa ')

    # Second job: only the identifying details are checked.
    second_job = tuner.get_game()
    tc.assertIsInstance(second_job, Game_job)
    tc.assertEqual(second_job.game_id, 'g0#1r0')
    tc.assertEqual(second_job.player_b.code, 'g0#1')
    tc.assertEqual(second_job.player_w.code, 'opp')

    # Nothing has been reported yet.
    tc.assertEqual(tuner.wins, [0, 0, 0, 0])

    # Report a win for black (the candidate 'g0#0') in the first game.
    black_win = Game_result({'b' : 'g0#0', 'w' : 'opp'}, 'b')
    black_win.sgf_result = "B+8.5"
    first_response = Game_job_result()
    first_response.game_id = first_job.game_id
    first_response.game_result = black_win
    first_response.engine_names = {
        'opp' : 'opp engine:v1.2.3',
        'g0#0' : 'candidate engine',
        }
    first_response.engine_descriptions = {
        'opp' : 'opp engine:v1.2.3',
        'g0#0' : 'candidate engine description',
        }
    first_response.game_data = first_job.game_data
    tuner.process_game_result(first_response)

    tc.assertEqual(tuner.wins, [1, 0, 0, 0])

    # The status must survive pickling and be loadable into a new tuner.
    restored_tuner = cem_tuners.Cem_tuner('cemtest')
    restored_tuner.initialise_from_control_file(default_config())
    pickled_status = pickle.dumps(tuner.get_status())
    restored_tuner.set_status(pickle.loads(pickled_status))
    tc.assertEqual(restored_tuner.wins, [1, 0, 0, 0])

    # Report a jigo in the second game; the expected tally credits it as 0.5.
    jigo = Game_result({'b' : 'g0#1', 'w' : 'opp'}, None)
    jigo.set_jigo()
    second_response = Game_job_result()
    second_response.game_id = second_job.game_id
    second_response.game_result = jigo
    second_response.engine_names = first_response.engine_names
    second_response.engine_descriptions = first_response.engine_descriptions
    second_response.game_data = second_job.game_data
    tuner.process_game_result(second_response)

    tc.assertEqual(tuner.wins, [1, 0.5, 0, 0])
|
||
|
|