forked from dmlc/xgboost
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_param_array.cc
More file actions
73 lines (67 loc) · 2.1 KB
/
test_param_array.cc
File metadata and controls
73 lines (67 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/**
* Copyright 2025, XGBoost contributors
*/
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <xgboost/base.h> // for kRtEps
#include <xgboost/json.h> // for Json
#include <xgboost/parameter.h> // for XGBoostParameter
#include <xgboost/string_view.h> // for StringView
#include <sstream> // for istringstream, ostringstream
#include <string> // for string
#include "../../../src/common/param_array.h"
#include "../helpers.h"
namespace xgboost::common {
// Exercises stream parsing and printing of `ParamArray<float>`:
//  1. a bare scalar is read as a one-element array,
//  2. a bracketed list is read element by element,
//  3. a list holding a string is rejected with a type error.
// In the two accepting cases the printed form is loaded back through `Json::Load` and compared
// element-wise against the parsed values.
TEST(ParamArray, Float) {
  ParamArray<float> values{"values"};

  // Case 1: scalar input.
  {
    std::istringstream input{"1.1"};
    input >> values;
    ASSERT_EQ(values.size(), 1);
    ASSERT_NEAR(values[0], 1.1, kRtEps);

    // Print the array and load the text back as JSON.
    std::ostringstream output;
    output << values;
    std::string const printed = output.str();
    auto loaded = Json::Load(StringView{printed});
    for (std::size_t idx = 0, n_values = values.size(); idx < n_values; ++idx) {
      ASSERT_EQ(get<Number const>(loaded[idx]), values[idx]);
    }
  }

  // Case 2: list input, reusing the same object so the earlier content is overwritten.
  {
    std::istringstream input{std::string{"[1.1, 1.3]"}};
    input >> values;
    ASSERT_EQ(values.size(), 2);
    ASSERT_NEAR(values[0], 1.1, kRtEps);
    ASSERT_NEAR(values[1], 1.3, kRtEps);

    // Print the array and load the text back as JSON.
    std::ostringstream output;
    output << values;
    std::string const printed = output.str();
    auto loaded = Json::Load(StringView{printed});
    for (std::size_t idx = 0, n_values = values.size(); idx < n_values; ++idx) {
      ASSERT_EQ(get<Number const>(loaded[idx]), values[idx]);
    }
  }

  // Case 3: a string element must raise an error that names the parameter and the accepted types.
  {
    ParamArray<float> rejected{"values"};
    std::istringstream input{"[\"foo\"]"};
    ASSERT_THAT(
        [&] { input >> rejected; },
        GMockThrow(
            R"(Invalid type for: `values`, expecting one of the: {`Number`, `Integer`}, got: `String`)"));
  }
}
// File-local helper for the `Update` test below.
namespace {
// Minimal parameter struct with a single `ParamArray<float>` field declared through the DMLC
// parameter macros.
struct TestParamArray : public XGBoostParameter<TestParamArray> {
// In-class initializer; it uses the same name and value as the declared default below.
ParamArray<float> test_key{"test_key", 0.2f};
DMLC_DECLARE_PARAMETER(TestParamArray) {
// Declare `test_key` as a field with a description string and a default value.
DMLC_DECLARE_FIELD(test_key).describe("test").set_default(ParamArray<float>{"test_key", 0.2f});
}
};
// NOTE(review): presumably supplies the definitions that DMLC_DECLARE_PARAMETER relies on --
// confirm against the dmlc-core parameter documentation.
DMLC_REGISTER_PARAMETER(TestParamArray);
} // namespace
// Runs `UpdateAllowUnknown` on a `TestParamArray` with an argument list that never mentions
// `test_key`, then checks that the field reports exactly one element and keeps its name.
TEST(ParamArray, Update) {
  TestParamArray tparam;
  Args const kwargs{{}};
  tparam.UpdateAllowUnknown(kwargs);

  auto& field = tparam.test_key;
  ASSERT_EQ(field.size(), 1);
  ASSERT_EQ(field.Name(), "test_key");
}
} // namespace xgboost::common