pygo/gomill/gomill_tests/cem_tuner_tests.py

228 lines
7.7 KiB
Python
Raw Normal View History

"""Tests for cem_tuners.py"""
from __future__ import with_statement, division
import cPickle as pickle
from textwrap import dedent
from gomill import cem_tuners
from gomill.game_jobs import Game_job, Game_job_result
from gomill.gtp_games import Game_result
from gomill.cem_tuners import Parameter_config
from gomill.competitions import (
Player_config, CompetitionError, ControlFileError)
from gomill_tests import gomill_test_support
def make_tests(suite):
suite.addTests(gomill_test_support.make_simple_tests(globals()))
def simple_make_candidate(*args):
if -1 in args:
raise ValueError("oops")
return Player_config("cand " + " ".join(map(str, args)))
def clip_axisb(f):
f = float(f)
return max(0.0, min(100.0, f))
def default_config():
return {
'board_size' : 13,
'komi' : 7.5,
'players' : {
'opp' : Player_config("test"),
},
'candidate_colour' : 'w',
'opponent' : 'opp',
'parameters' : [
Parameter_config(
'axisa',
initial_mean = 0.5,
initial_variance = 1.0,
format = "axa %.3f"),
Parameter_config(
'axisb',
initial_mean = 50.0,
initial_variance = 1000.0,
transform = clip_axisb,
format = "axb %.1f"),
],
'batch_size' : 3,
'samples_per_generation' : 4,
'number_of_generations' : 3,
'elite_proportion' : 0.1,
'step_size' : 0.8,
'make_candidate' : simple_make_candidate,
}
def test_parameter_config(tc):
comp = cem_tuners.Cem_tuner('cemtest')
config = default_config()
comp.initialise_from_control_file(config)
tc.assertEqual(comp.initial_distribution.format(),
" 0.50~1.00 50.00~1000.00")
tc.assertEqual(comp.format_engine_parameters((0.5, 23)),
"axa 0.500; axb 23.0")
tc.assertEqual(comp.format_engine_parameters(('x', 23)),
"[axisa?x]; axb 23.0")
tc.assertEqual(comp.format_optimiser_parameters((0.5, 500)),
"axa 0.500; axb 100.0")
tc.assertEqual(comp.transform_parameters((0.5, 1e6)), (0.5, 100.0))
with tc.assertRaises(CompetitionError) as ar:
comp.transform_parameters((0.5, None))
tc.assertTracebackStringEqual(str(ar.exception), dedent("""\
error from transform for axisb
TypeError: expected-float
traceback (most recent call last):
cem_tuner_tests|clip_axisb
failing line:
f = float(f)
"""), fixups=[
("float() argument must be a string or a number", "expected-float"),
("expected float, got NoneType object", "expected-float"),
])
tc.assertRaisesRegexp(
ValueError, "'initial_variance': must be nonnegative",
comp.parameter_spec_from_config,
Parameter_config('pa1', initial_mean=0,
initial_variance=-1, format="%s"))
tc.assertRaisesRegexp(
ControlFileError, "'format': invalid format string",
comp.parameter_spec_from_config,
Parameter_config('pa1', initial_mean=0, initial_variance=1,
format="nopct"))
def test_nonsense_parameter_config(tc):
comp = cem_tuners.Cem_tuner('cemtest')
config = default_config()
config['parameters'].append(99)
with tc.assertRaises(ControlFileError) as ar:
comp.initialise_from_control_file(config)
tc.assertMultiLineEqual(str(ar.exception), dedent("""\
'parameters': item 2: not a Parameter"""))
def test_transform_check(tc):
comp = cem_tuners.Cem_tuner('cemtest')
config = default_config()
config['parameters'][0] = Parameter_config(
'axisa',
initial_mean = 0.5,
initial_variance = 1.0,
transform = str.split)
with tc.assertRaises(ControlFileError) as ar:
comp.initialise_from_control_file(config)
tc.assertTracebackStringEqual(str(ar.exception), dedent("""\
parameter axisa: error from transform (applied to initial_mean)
TypeError: split-wants-float-not-str
traceback (most recent call last):
"""), fixups=[
("descriptor 'split' requires a 'str' object but received a 'float'",
"split-wants-float-not-str"),
("unbound method split() must be called with str instance as "
"first argument (got float instance instead)",
"split-wants-float-not-str"),
])
def test_format_validation(tc):
comp = cem_tuners.Cem_tuner('cemtest')
config = default_config()
config['parameters'][0] = Parameter_config(
'axisa',
initial_mean = 0.5,
initial_variance = 1.0,
transform = str,
format = "axa %f")
tc.assertRaisesRegexp(
ControlFileError, "'format': invalid format string",
comp.initialise_from_control_file, config)
def test_make_candidate(tc):
comp = cem_tuners.Cem_tuner('cemtest')
config = default_config()
comp.initialise_from_control_file(config)
cand = comp.make_candidate('g0#1', (0.5, 23.0))
tc.assertEqual(cand.code, 'g0#1')
tc.assertListEqual(cand.cmd_args, ['cand', '0.5', '23.0'])
with tc.assertRaises(CompetitionError) as ar:
comp.make_candidate('g0#1', (-1, 23))
tc.assertTracebackStringEqual(str(ar.exception), dedent("""\
error from make_candidate()
ValueError: oops
traceback (most recent call last):
cem_tuner_tests|simple_make_candidate
failing line:
raise ValueError("oops")
"""))
def test_play(tc):
comp = cem_tuners.Cem_tuner('cemtest')
comp.initialise_from_control_file(default_config())
comp.set_clean_status()
tc.assertEqual(comp.generation, 0)
tc.assertEqual(comp.distribution.format(),
" 0.50~1.00 50.00~1000.00")
job1 = comp.get_game()
tc.assertIsInstance(job1, Game_job)
tc.assertEqual(job1.game_id, 'g0#0r0')
tc.assertEqual(job1.player_b.code, 'g0#0')
tc.assertEqual(job1.player_w.code, 'opp')
tc.assertEqual(job1.board_size, 13)
tc.assertEqual(job1.komi, 7.5)
tc.assertEqual(job1.move_limit, 1000)
tc.assertIs(job1.use_internal_scorer, False)
tc.assertEqual(job1.internal_scorer_handicap_compensation, 'full')
tc.assertEqual(job1.game_data, (0, 'g0#0', 0))
tc.assertEqual(job1.sgf_event, 'cemtest')
tc.assertRegexpMatches(job1.sgf_note, '^Candidate parameters: axa ')
job2 = comp.get_game()
tc.assertIsInstance(job2, Game_job)
tc.assertEqual(job2.game_id, 'g0#1r0')
tc.assertEqual(job2.player_b.code, 'g0#1')
tc.assertEqual(job2.player_w.code, 'opp')
tc.assertEqual(comp.wins, [0, 0, 0, 0])
result1 = Game_result({'b' : 'g0#0', 'w' : 'opp'}, 'b')
result1.sgf_result = "B+8.5"
response1 = Game_job_result()
response1.game_id = job1.game_id
response1.game_result = result1
response1.engine_names = {
'opp' : 'opp engine:v1.2.3',
'g0#0' : 'candidate engine',
}
response1.engine_descriptions = {
'opp' : 'opp engine:v1.2.3',
'g0#0' : 'candidate engine description',
}
response1.game_data = job1.game_data
comp.process_game_result(response1)
tc.assertEqual(comp.wins, [1, 0, 0, 0])
comp2 = cem_tuners.Cem_tuner('cemtest')
comp2.initialise_from_control_file(default_config())
status = pickle.loads(pickle.dumps(comp.get_status()))
comp2.set_status(status)
tc.assertEqual(comp2.wins, [1, 0, 0, 0])
result2 = Game_result({'b' : 'g0#1', 'w' : 'opp'}, None)
result2.set_jigo()
response2 = Game_job_result()
response2.game_id = job2.game_id
response2.game_result = result2
response2.engine_names = response1.engine_names
response2.engine_descriptions = response1.engine_descriptions
response2.game_data = job2.game_data
comp.process_game_result(response2)
tc.assertEqual(comp.wins, [1, 0.5, 0, 0])