call.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. // Copyright 2010 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package gomock
  15. import (
  16. "fmt"
  17. "reflect"
  18. "strconv"
  19. "strings"
  20. )
  21. // Call represents an expected call to a mock.
  22. type Call struct {
  23. t TestHelper // for triggering test failures on invalid call setup
  24. receiver interface{} // the receiver of the method call
  25. method string // the name of the method
  26. methodType reflect.Type // the type of the method
  27. args []Matcher // the args
  28. origin string // file and line number of call setup
  29. preReqs []*Call // prerequisite calls
  30. // Expectations
  31. minCalls, maxCalls int
  32. numCalls int // actual number made
  33. // actions are called when this Call is called. Each action gets the args and
  34. // can set the return values by returning a non-nil slice. Actions run in the
  35. // order they are created.
  36. actions []func([]interface{}) []interface{}
  37. }
  38. // newCall creates a *Call. It requires the method type in order to support
  39. // unexported methods.
  40. func newCall(t TestHelper, receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
  41. t.Helper()
  42. // TODO: check arity, types.
  43. mArgs := make([]Matcher, len(args))
  44. for i, arg := range args {
  45. if m, ok := arg.(Matcher); ok {
  46. mArgs[i] = m
  47. } else if arg == nil {
  48. // Handle nil specially so that passing a nil interface value
  49. // will match the typed nils of concrete args.
  50. mArgs[i] = Nil()
  51. } else {
  52. mArgs[i] = Eq(arg)
  53. }
  54. }
  55. // callerInfo's skip should be updated if the number of calls between the user's test
  56. // and this line changes, i.e. this code is wrapped in another anonymous function.
  57. // 0 is us, 1 is RecordCallWithMethodType(), 2 is the generated recorder, and 3 is the user's test.
  58. origin := callerInfo(3)
  59. actions := []func([]interface{}) []interface{}{func([]interface{}) []interface{} {
  60. // Synthesize the zero value for each of the return args' types.
  61. rets := make([]interface{}, methodType.NumOut())
  62. for i := 0; i < methodType.NumOut(); i++ {
  63. rets[i] = reflect.Zero(methodType.Out(i)).Interface()
  64. }
  65. return rets
  66. }}
  67. return &Call{t: t, receiver: receiver, method: method, methodType: methodType,
  68. args: mArgs, origin: origin, minCalls: 1, maxCalls: 1, actions: actions}
  69. }
  70. // AnyTimes allows the expectation to be called 0 or more times
  71. func (c *Call) AnyTimes() *Call {
  72. c.minCalls, c.maxCalls = 0, 1e8 // close enough to infinity
  73. return c
  74. }
  75. // MinTimes requires the call to occur at least n times. If AnyTimes or MaxTimes have not been called or if MaxTimes
  76. // was previously called with 1, MinTimes also sets the maximum number of calls to infinity.
  77. func (c *Call) MinTimes(n int) *Call {
  78. c.minCalls = n
  79. if c.maxCalls == 1 {
  80. c.maxCalls = 1e8
  81. }
  82. return c
  83. }
  84. // MaxTimes limits the number of calls to n times. If AnyTimes or MinTimes have not been called or if MinTimes was
  85. // previously called with 1, MaxTimes also sets the minimum number of calls to 0.
  86. func (c *Call) MaxTimes(n int) *Call {
  87. c.maxCalls = n
  88. if c.minCalls == 1 {
  89. c.minCalls = 0
  90. }
  91. return c
  92. }
  93. // DoAndReturn declares the action to run when the call is matched.
  94. // The return values from this function are returned by the mocked function.
  95. // It takes an interface{} argument to support n-arity functions.
  96. func (c *Call) DoAndReturn(f interface{}) *Call {
  97. // TODO: Check arity and types here, rather than dying badly elsewhere.
  98. v := reflect.ValueOf(f)
  99. c.addAction(func(args []interface{}) []interface{} {
  100. c.t.Helper()
  101. vArgs := make([]reflect.Value, len(args))
  102. ft := v.Type()
  103. if c.methodType.NumIn() != ft.NumIn() {
  104. c.t.Fatalf("wrong number of arguments in DoAndReturn func for %T.%v: got %d, want %d [%s]",
  105. c.receiver, c.method, ft.NumIn(), c.methodType.NumIn(), c.origin)
  106. return nil
  107. }
  108. for i := 0; i < len(args); i++ {
  109. if args[i] != nil {
  110. vArgs[i] = reflect.ValueOf(args[i])
  111. } else {
  112. // Use the zero value for the arg.
  113. vArgs[i] = reflect.Zero(ft.In(i))
  114. }
  115. }
  116. vRets := v.Call(vArgs)
  117. rets := make([]interface{}, len(vRets))
  118. for i, ret := range vRets {
  119. rets[i] = ret.Interface()
  120. }
  121. return rets
  122. })
  123. return c
  124. }
  125. // Do declares the action to run when the call is matched. The function's
  126. // return values are ignored to retain backward compatibility. To use the
  127. // return values call DoAndReturn.
  128. // It takes an interface{} argument to support n-arity functions.
  129. func (c *Call) Do(f interface{}) *Call {
  130. // TODO: Check arity and types here, rather than dying badly elsewhere.
  131. v := reflect.ValueOf(f)
  132. c.addAction(func(args []interface{}) []interface{} {
  133. c.t.Helper()
  134. if c.methodType.NumIn() != v.Type().NumIn() {
  135. c.t.Fatalf("wrong number of arguments in Do func for %T.%v: got %d, want %d [%s]",
  136. c.receiver, c.method, v.Type().NumIn(), c.methodType.NumIn(), c.origin)
  137. return nil
  138. }
  139. vArgs := make([]reflect.Value, len(args))
  140. ft := v.Type()
  141. for i := 0; i < len(args); i++ {
  142. if args[i] != nil {
  143. vArgs[i] = reflect.ValueOf(args[i])
  144. } else {
  145. // Use the zero value for the arg.
  146. vArgs[i] = reflect.Zero(ft.In(i))
  147. }
  148. }
  149. v.Call(vArgs)
  150. return nil
  151. })
  152. return c
  153. }
  154. // Return declares the values to be returned by the mocked function call.
  155. func (c *Call) Return(rets ...interface{}) *Call {
  156. c.t.Helper()
  157. mt := c.methodType
  158. if len(rets) != mt.NumOut() {
  159. c.t.Fatalf("wrong number of arguments to Return for %T.%v: got %d, want %d [%s]",
  160. c.receiver, c.method, len(rets), mt.NumOut(), c.origin)
  161. }
  162. for i, ret := range rets {
  163. if got, want := reflect.TypeOf(ret), mt.Out(i); got == want {
  164. // Identical types; nothing to do.
  165. } else if got == nil {
  166. // Nil needs special handling.
  167. switch want.Kind() {
  168. case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
  169. // ok
  170. default:
  171. c.t.Fatalf("argument %d to Return for %T.%v is nil, but %v is not nillable [%s]",
  172. i, c.receiver, c.method, want, c.origin)
  173. }
  174. } else if got.AssignableTo(want) {
  175. // Assignable type relation. Make the assignment now so that the generated code
  176. // can return the values with a type assertion.
  177. v := reflect.New(want).Elem()
  178. v.Set(reflect.ValueOf(ret))
  179. rets[i] = v.Interface()
  180. } else {
  181. c.t.Fatalf("wrong type of argument %d to Return for %T.%v: %v is not assignable to %v [%s]",
  182. i, c.receiver, c.method, got, want, c.origin)
  183. }
  184. }
  185. c.addAction(func([]interface{}) []interface{} {
  186. return rets
  187. })
  188. return c
  189. }
  190. // Times declares the exact number of times a function call is expected to be executed.
  191. func (c *Call) Times(n int) *Call {
  192. c.minCalls, c.maxCalls = n, n
  193. return c
  194. }
  195. // SetArg declares an action that will set the nth argument's value,
  196. // indirected through a pointer. Or, in the case of a slice, SetArg
  197. // will copy value's elements into the nth argument.
  198. func (c *Call) SetArg(n int, value interface{}) *Call {
  199. c.t.Helper()
  200. mt := c.methodType
  201. // TODO: This will break on variadic methods.
  202. // We will need to check those at invocation time.
  203. if n < 0 || n >= mt.NumIn() {
  204. c.t.Fatalf("SetArg(%d, ...) called for a method with %d args [%s]",
  205. n, mt.NumIn(), c.origin)
  206. }
  207. // Permit setting argument through an interface.
  208. // In the interface case, we don't (nay, can't) check the type here.
  209. at := mt.In(n)
  210. switch at.Kind() {
  211. case reflect.Ptr:
  212. dt := at.Elem()
  213. if vt := reflect.TypeOf(value); !vt.AssignableTo(dt) {
  214. c.t.Fatalf("SetArg(%d, ...) argument is a %v, not assignable to %v [%s]",
  215. n, vt, dt, c.origin)
  216. }
  217. case reflect.Interface:
  218. // nothing to do
  219. case reflect.Slice:
  220. // nothing to do
  221. default:
  222. c.t.Fatalf("SetArg(%d, ...) referring to argument of non-pointer non-interface non-slice type %v [%s]",
  223. n, at, c.origin)
  224. }
  225. c.addAction(func(args []interface{}) []interface{} {
  226. v := reflect.ValueOf(value)
  227. switch reflect.TypeOf(args[n]).Kind() {
  228. case reflect.Slice:
  229. setSlice(args[n], v)
  230. default:
  231. reflect.ValueOf(args[n]).Elem().Set(v)
  232. }
  233. return nil
  234. })
  235. return c
  236. }
  237. // isPreReq returns true if other is a direct or indirect prerequisite to c.
  238. func (c *Call) isPreReq(other *Call) bool {
  239. for _, preReq := range c.preReqs {
  240. if other == preReq || preReq.isPreReq(other) {
  241. return true
  242. }
  243. }
  244. return false
  245. }
  246. // After declares that the call may only match after preReq has been exhausted.
  247. func (c *Call) After(preReq *Call) *Call {
  248. c.t.Helper()
  249. if c == preReq {
  250. c.t.Fatalf("A call isn't allowed to be its own prerequisite")
  251. }
  252. if preReq.isPreReq(c) {
  253. c.t.Fatalf("Loop in call order: %v is a prerequisite to %v (possibly indirectly).", c, preReq)
  254. }
  255. c.preReqs = append(c.preReqs, preReq)
  256. return c
  257. }
  258. // Returns true if the minimum number of calls have been made.
  259. func (c *Call) satisfied() bool {
  260. return c.numCalls >= c.minCalls
  261. }
  262. // Returns true if the maximum number of calls have been made.
  263. func (c *Call) exhausted() bool {
  264. return c.numCalls >= c.maxCalls
  265. }
  266. func (c *Call) String() string {
  267. args := make([]string, len(c.args))
  268. for i, arg := range c.args {
  269. args[i] = arg.String()
  270. }
  271. arguments := strings.Join(args, ", ")
  272. return fmt.Sprintf("%T.%v(%s) %s", c.receiver, c.method, arguments, c.origin)
  273. }
  274. // Tests if the given call matches the expected call.
  275. // If yes, returns nil. If no, returns error with message explaining why it does not match.
  276. func (c *Call) matches(args []interface{}) error {
  277. if !c.methodType.IsVariadic() {
  278. if len(args) != len(c.args) {
  279. return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
  280. c.origin, len(args), len(c.args))
  281. }
  282. for i, m := range c.args {
  283. if !m.Matches(args[i]) {
  284. return fmt.Errorf(
  285. "expected call at %s doesn't match the argument at index %d.\nGot: %v\nWant: %v",
  286. c.origin, i, formatGottenArg(m, args[i]), m,
  287. )
  288. }
  289. }
  290. } else {
  291. if len(c.args) < c.methodType.NumIn()-1 {
  292. return fmt.Errorf("expected call at %s has the wrong number of matchers. Got: %d, want: %d",
  293. c.origin, len(c.args), c.methodType.NumIn()-1)
  294. }
  295. if len(c.args) != c.methodType.NumIn() && len(args) != len(c.args) {
  296. return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: %d",
  297. c.origin, len(args), len(c.args))
  298. }
  299. if len(args) < len(c.args)-1 {
  300. return fmt.Errorf("expected call at %s has the wrong number of arguments. Got: %d, want: greater than or equal to %d",
  301. c.origin, len(args), len(c.args)-1)
  302. }
  303. for i, m := range c.args {
  304. if i < c.methodType.NumIn()-1 {
  305. // Non-variadic args
  306. if !m.Matches(args[i]) {
  307. return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
  308. c.origin, strconv.Itoa(i), formatGottenArg(m, args[i]), m)
  309. }
  310. continue
  311. }
  312. // The last arg has a possibility of a variadic argument, so let it branch
  313. // sample: Foo(a int, b int, c ...int)
  314. if i < len(c.args) && i < len(args) {
  315. if m.Matches(args[i]) {
  316. // Got Foo(a, b, c) want Foo(matcherA, matcherB, gomock.Any())
  317. // Got Foo(a, b, c) want Foo(matcherA, matcherB, someSliceMatcher)
  318. // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC)
  319. // Got Foo(a, b) want Foo(matcherA, matcherB)
  320. // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD)
  321. continue
  322. }
  323. }
  324. // The number of actual args don't match the number of matchers,
  325. // or the last matcher is a slice and the last arg is not.
  326. // If this function still matches it is because the last matcher
  327. // matches all the remaining arguments or the lack of any.
  328. // Convert the remaining arguments, if any, into a slice of the
  329. // expected type.
  330. vArgsType := c.methodType.In(c.methodType.NumIn() - 1)
  331. vArgs := reflect.MakeSlice(vArgsType, 0, len(args)-i)
  332. for _, arg := range args[i:] {
  333. vArgs = reflect.Append(vArgs, reflect.ValueOf(arg))
  334. }
  335. if m.Matches(vArgs.Interface()) {
  336. // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, gomock.Any())
  337. // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, someSliceMatcher)
  338. // Got Foo(a, b) want Foo(matcherA, matcherB, gomock.Any())
  339. // Got Foo(a, b) want Foo(matcherA, matcherB, someEmptySliceMatcher)
  340. break
  341. }
  342. // Wrong number of matchers or not match. Fail.
  343. // Got Foo(a, b) want Foo(matcherA, matcherB, matcherC, matcherD)
  344. // Got Foo(a, b, c) want Foo(matcherA, matcherB, matcherC, matcherD)
  345. // Got Foo(a, b, c, d) want Foo(matcherA, matcherB, matcherC, matcherD, matcherE)
  346. // Got Foo(a, b, c, d, e) want Foo(matcherA, matcherB, matcherC, matcherD)
  347. // Got Foo(a, b, c) want Foo(matcherA, matcherB)
  348. return fmt.Errorf("expected call at %s doesn't match the argument at index %s.\nGot: %v\nWant: %v",
  349. c.origin, strconv.Itoa(i), formatGottenArg(m, args[i:]), c.args[i])
  350. }
  351. }
  352. // Check that all prerequisite calls have been satisfied.
  353. for _, preReqCall := range c.preReqs {
  354. if !preReqCall.satisfied() {
  355. return fmt.Errorf("expected call at %s doesn't have a prerequisite call satisfied:\n%v\nshould be called before:\n%v",
  356. c.origin, preReqCall, c)
  357. }
  358. }
  359. // Check that the call is not exhausted.
  360. if c.exhausted() {
  361. return fmt.Errorf("expected call at %s has already been called the max number of times", c.origin)
  362. }
  363. return nil
  364. }
  365. // dropPrereqs tells the expected Call to not re-check prerequisite calls any
  366. // longer, and to return its current set.
  367. func (c *Call) dropPrereqs() (preReqs []*Call) {
  368. preReqs = c.preReqs
  369. c.preReqs = nil
  370. return
  371. }
  372. func (c *Call) call() []func([]interface{}) []interface{} {
  373. c.numCalls++
  374. return c.actions
  375. }
  376. // InOrder declares that the given calls should occur in order.
  377. func InOrder(calls ...*Call) {
  378. for i := 1; i < len(calls); i++ {
  379. calls[i].After(calls[i-1])
  380. }
  381. }
  382. func setSlice(arg interface{}, v reflect.Value) {
  383. va := reflect.ValueOf(arg)
  384. for i := 0; i < v.Len(); i++ {
  385. va.Index(i).Set(v.Index(i))
  386. }
  387. }
  388. func (c *Call) addAction(action func([]interface{}) []interface{}) {
  389. c.actions = append(c.actions, action)
  390. }
  391. func formatGottenArg(m Matcher, arg interface{}) string {
  392. got := fmt.Sprintf("%v (%T)", arg, arg)
  393. if gs, ok := m.(GotFormatter); ok {
  394. got = gs.Got(arg)
  395. }
  396. return got
  397. }